spuco.group_inference

Group Inference Methods.

Cluster

class ClusterAlg(value)

Bases: Enum

Enumeration of the supported clustering algorithms.

KMEANS = 'kmeans'
KMEDOIDS = 'kmedoids'
class Cluster(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, device: device = device(type='cpu'), verbose: bool = False)

Bases: BaseGroupInference

Clustering-based Group Inference

__init__(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, device: device = device(type='cpu'), verbose: bool = False)

Initializes the Cluster object.

Parameters:
  • Z (torch.Tensor) – The input tensor for clustering.

  • class_labels (Optional[List[int]], optional) – Optional list of class labels for class-wise clustering. Defaults to None.

  • cluster_alg (ClusterAlg, optional) – The clustering algorithm to use. Defaults to ClusterAlg.KMEANS.

  • num_clusters (int, optional) – The number of clusters to create. Defaults to -1.

  • max_clusters (int, optional) – The maximum number of clusters to consider. Defaults to -1.

  • device (torch.device, optional) – The device to run the clustering on. Defaults to torch.device("cpu").

  • verbose (bool, optional) – Whether to display progress and logging information. Defaults to False.

infer_groups() → Dict[int, List[int]]

Infers the group partition based on the clustering results.

Returns:

The group partition where each key is a cluster label and the value is a list of indices belonging to that cluster.

Return type:

Dict[int, List[int]]

silhouette(Z)

Uses the silhouette score to determine the optimal number of clusters and performs clustering with self.cluster_alg.

Parameters:

Z (torch.Tensor) – The input data for clustering.

Returns:

The cluster partition based on the optimal number of clusters.

Return type:

List[int]
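The silhouette-based selection described for silhouette() can be sketched in plain Python. This is a minimal sketch, not SpuCo's implementation: it assumes Euclidean distance and takes the candidate labelings as already computed (the real method presumably runs self.cluster_alg for each candidate number of clusters up to max_clusters). The helper names silhouette_score and pick_best_labeling are hypothetical.

```python
import math
from typing import List, Sequence


def silhouette_score(points: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean silhouette coefficient using Euclidean distance."""
    clusters = sorted(set(labels))
    if len(clusters) < 2:
        raise ValueError("the silhouette score needs at least two clusters")
    total = 0.0
    for i, p in enumerate(points):
        # Sum and count of distances from point i to each cluster (excluding i itself).
        sums = {c: 0.0 for c in clusters}
        counts = {c: 0 for c in clusters}
        for j, q in enumerate(points):
            if i != j:
                sums[labels[j]] += math.dist(p, q)
                counts[labels[j]] += 1
        own = labels[i]
        if counts[own] == 0:
            continue  # singleton cluster: its silhouette is defined as 0
        a = sums[own] / counts[own]                                   # cohesion
        b = min(sums[c] / counts[c] for c in clusters if c != own)  # separation
        if max(a, b) > 0:
            total += (b - a) / max(a, b)
    return total / len(points)


def pick_best_labeling(points: Sequence[Sequence[float]],
                       candidates: List[List[int]]) -> List[int]:
    """Return the candidate labeling with the highest silhouette score."""
    return max(candidates, key=lambda labels: silhouette_score(points, labels))


# Two well-separated blobs: the 2-cluster labeling should score highest.
pts = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
best = pick_best_labeling(pts, [[0, 0, 1, 1], [0, 1, 2, 2]])
```

Splitting a tight blob into singletons drives those points' silhouettes to 0, which is why the coarser labeling wins here.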

kmeans(Z, num_clusters: int = -1)

Performs K-means clustering on the input data.

Parameters:
  • Z (torch.Tensor) – The input data for clustering.

  • num_clusters (int, optional) – The number of clusters to create. If not specified, the value from the object will be used.

Returns:

The cluster labels and partition based on the K-means clustering.

Return type:

Tuple[np.ndarray, List[List[int]]]
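kmeans() returns both per-sample cluster labels and a partition (one list of indices per cluster). A minimal pure-Python sketch of Lloyd's algorithm that produces the same pair of outputs; plain lists stand in for np.ndarray and torch.Tensor, and the deterministic initialization from the first num_clusters points is an assumption made so the example is reproducible.

```python
import math
from typing import List, Sequence, Tuple


def kmeans(points: Sequence[Sequence[float]], num_clusters: int,
           max_iters: int = 100) -> Tuple[List[int], List[List[int]]]:
    """Lloyd's algorithm; returns (labels, partition)."""
    if not 0 < num_clusters <= len(points):
        raise ValueError("num_clusters must be between 1 and len(points)")
    # Deterministic initialization: the first num_clusters points are the seeds.
    centroids = [list(p) for p in points[:num_clusters]]
    labels: List[int] = []
    for _ in range(max_iters):
        # Assignment step: each point joins its nearest centroid.
        new_labels = [min(range(num_clusters), key=lambda c: math.dist(p, centroids[c]))
                      for p in points]
        if new_labels == labels:
            break  # converged: assignments stopped changing
        labels = new_labels
        # Update step: move each centroid to the mean of its members.
        for c in range(num_clusters):
            members = [points[i] for i, lab in enumerate(labels) if lab == c]
            if members:  # an empty cluster keeps its previous centroid
                centroids[c] = [sum(coord) / len(members) for coord in zip(*members)]
    partition = [[i for i, lab in enumerate(labels) if lab == c] for c in range(num_clusters)]
    return labels, partition


labels, partition = kmeans([(0.0, 0.0), (10.0, 10.0), (0.0, 1.0), (10.0, 11.0)], num_clusters=2)
```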

kmedoids(Z, similiarity_matrix: Tensor, num_clusters=-1)

Performs K-medoids clustering on the input data.

Parameters:
  • Z (torch.Tensor) – The input data for clustering.

  • similiarity_matrix (torch.Tensor) – The similarity matrix for pairwise similarities between data points.

  • num_clusters (int, optional) – The number of clusters to create. If not specified, the value from the object will be used.

Returns:

The cluster labels and partition based on the K-medoids clustering.

Return type:

Tuple[np.ndarray, List[List[int]]]
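kmedoids() works from a pairwise similarity matrix rather than raw coordinates. A minimal sketch of the alternating k-medoids heuristic on such a matrix: each point joins the medoid it is most similar to, then each cluster's medoid becomes the member with the highest total similarity to the rest of its cluster. The greedy initialization and the function name are assumptions; this page does not say whether SpuCo uses this variant or a PAM-style swap search.

```python
from typing import List, Sequence, Tuple


def kmedoids(similarity: Sequence[Sequence[float]], num_clusters: int,
             max_iters: int = 100) -> Tuple[List[int], List[List[int]]]:
    """Alternating k-medoids on a similarity matrix; returns (labels, partition)."""
    n = len(similarity)
    if not 0 < num_clusters <= n:
        raise ValueError("num_clusters must be between 1 and n")
    # Init: the most "central" point, then greedily the point least similar
    # to any medoid chosen so far.
    medoids = [max(range(n), key=lambda i: sum(similarity[i]))]
    while len(medoids) < num_clusters:
        medoids.append(min((i for i in range(n) if i not in medoids),
                           key=lambda i: max(similarity[i][m] for m in medoids)))
    labels: List[int] = []
    for _ in range(max_iters):
        # Assignment: each point goes to the medoid it is most similar to.
        new_labels = [max(range(num_clusters), key=lambda c: similarity[i][medoids[c]])
                      for i in range(n)]
        if new_labels == labels:
            break
        labels = new_labels
        # Update: the new medoid maximizes total within-cluster similarity.
        for c in range(num_clusters):
            members = [i for i in range(n) if labels[i] == c]
            if members:
                medoids[c] = max(members, key=lambda m: sum(similarity[m][j] for j in members))
    partition = [[i for i in range(n) if labels[i] == c] for c in range(num_clusters)]
    return labels, partition


sim = [
    [1.0, 0.9, 0.1, 0.0],
    [0.9, 1.0, 0.2, 0.1],
    [0.1, 0.2, 1.0, 0.8],
    [0.0, 0.1, 0.8, 1.0],
]
labels, partition = kmedoids(sim, num_clusters=2)
```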

Just Train Twice (JTT) Inference

class JTTInference(predictions: List[int], class_labels: List[int])

Bases: BaseGroupInference

Just Train Twice Inference: https://arxiv.org/abs/2107.09044

__init__(predictions: List[int], class_labels: List[int])

Initializes JTTInference.

Parameters:
  • predictions – List of predicted labels.

  • class_labels – List of true class labels.

infer_groups() → Dict[Tuple[int, int], List[int]]

Infers group partitions based on predictions and class labels.

Returns:

Dictionary mapping group tuples to indices of examples belonging to each group.
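In JTT, the examples that an initial ERM model misclassifies form the error set that is later upweighted. A minimal sketch of turning predictions and class labels into a partition keyed by group tuples; keying each group as (class label, 1 if correctly classified else 0) is an assumption for illustration, not necessarily the tuple convention JTTInference uses, and jtt_groups is a hypothetical name.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def jtt_groups(predictions: List[int], class_labels: List[int]) -> Dict[Tuple[int, int], List[int]]:
    """Group examples by (true class, correctly classified?): hypothetical keying."""
    if len(predictions) != len(class_labels):
        raise ValueError("predictions and class_labels must have the same length")
    groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for idx, (pred, label) in enumerate(zip(predictions, class_labels)):
        groups[(label, int(pred == label))].append(idx)
    return dict(groups)


# Indices 1 and 3 are misclassified, so they land in the "error" groups.
partition = jtt_groups(predictions=[0, 1, 1, 0], class_labels=[0, 0, 1, 1])
```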

Environment Inference for Invariant Learning (EIIL)

class EIIL(logits: Tensor, class_labels: List[int], num_steps: int, lr: float, device: device = device(type='cpu'), verbose: bool = False)

Bases: BaseGroupInference

Environment Inference for Invariant Learning: https://arxiv.org/abs/2010.07249

__init__(logits: Tensor, class_labels: List[int], num_steps: int, lr: float, device: device = device(type='cpu'), verbose: bool = False)

Initializes the EIIL object.

Parameters:
  • logits (torch.Tensor) – The logits produced by the model.

  • class_labels (List[int]) – The class labels for each sample.

  • num_steps (int) – The number of steps for training the soft environment assignment.

  • lr (float) – The learning rate for training.

  • device (torch.device, optional) – The device to use for training. Defaults to CPU.

  • verbose (bool, optional) – Whether to print training progress. Defaults to False.

infer_groups()

Performs EIIL inference to infer group partitions.

Returns:

The group partition based on EIIL inference.

Return type:

Dict[Tuple[int, int], List[int]]
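EIIL infers environments by searching for the soft split of the data that maximally violates the IRM invariance penalty of a fixed reference model, then hardening that split. A minimal sketch of the optimization for two environments, assuming softmax cross-entropy on the given logits and a sigmoid-parameterized soft assignment q_i; the gradient is written out analytically instead of using autograd, and the (class label, environment) keying of the result is an assumption. Function names are hypothetical.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def scale_grads(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> List[float]:
    """d_i = d/dw CE(w * z_i, y_i) at w = 1, which equals E_softmax(z_i)[z] - z_i[y_i]."""
    d = []
    for z, y in zip(logits, labels):
        m = max(z)  # shift by the max for a numerically stable softmax
        exps = [math.exp(v - m) for v in z]
        total = sum(exps)
        d.append(sum(e / total * v for e, v in zip(exps, z)) - z[y])
    return d


def penalty_and_grad(d: Sequence[float], theta: Sequence[float]) -> Tuple[float, List[float]]:
    """IRM penalty g1^2 + g2^2 for the soft split q = sigmoid(theta), plus its gradient."""
    q = [1.0 / (1.0 + math.exp(-t)) for t in theta]
    size1 = sum(q)            # soft size of environment 1
    size2 = len(q) - size1    # soft size of environment 2
    g1 = sum(qi * di for qi, di in zip(q, d)) / size1
    g2 = sum((1.0 - qi) * di for qi, di in zip(q, d)) / size2
    # Chain rule: dP/dq_j, then dq_j/dtheta_j = q_j * (1 - q_j).
    grad = [(2 * g1 * (dj - g1) / size1 - 2 * g2 * (dj - g2) / size2) * qj * (1.0 - qj)
            for qj, dj in zip(q, d)]
    return g1 ** 2 + g2 ** 2, grad


def eiil_groups(logits: Sequence[Sequence[float]], class_labels: Sequence[int],
                num_steps: int = 500, lr: float = 0.1,
                seed: int = 0) -> Dict[Tuple[int, int], List[int]]:
    """Gradient ascent on the penalty, then a hard split at q = 0.5."""
    d = scale_grads(logits, class_labels)
    rng = random.Random(seed)
    # Start off-center: q_i = 0.5 for every i is a stationary point (zero gradient).
    theta = [rng.uniform(-0.1, 0.1) for _ in d]
    for _step in range(num_steps):
        grad = penalty_and_grad(d, theta)[1]
        theta = [t + lr * g for t, g in zip(theta, grad)]  # ascent: maximize the violation
    groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for i, (t, y) in enumerate(zip(theta, class_labels)):
        groups[(y, 0 if t >= 0 else 1)].append(i)  # t >= 0 exactly when q >= 0.5
    return dict(groups)
```

Maximizing rather than minimizing the penalty is the key design point: the inferred environments are the ones on which the reference model's optimal loss scaling disagrees most.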

Spread Spurious Attribute (SSA)

class SSA(spurious_unlabeled_dataset: BaseSpuCoCompatibleDataset, spurious_labeled_dataset: SpuriousTargetDatasetWrapper, model: Module, num_iters: int, labeled_valset_size: float = 0.5, tau_g_min: float = 0.95, lr: float = 0.01, weight_decay: float = 0.0005, batch_size: int = 64, num_splits: int = 3, device: device = device(type='cpu'), verbose: bool = False)

Bases: BaseGroupInference

Spread Spurious Attribute: https://arxiv.org/abs/2204.02070

__init__(spurious_unlabeled_dataset: BaseSpuCoCompatibleDataset, spurious_labeled_dataset: SpuriousTargetDatasetWrapper, model: Module, num_iters: int, labeled_valset_size: float = 0.5, tau_g_min: float = 0.95, lr: float = 0.01, weight_decay: float = 0.0005, batch_size: int = 64, num_splits: int = 3, device: device = device(type='cpu'), verbose: bool = False)

Initializes SSA.

Parameters:
  • spurious_unlabeled_dataset (BaseSpuCoCompatibleDataset) – The dataset of samples without spurious-attribute labels.

  • spurious_labeled_dataset (SpuriousTargetDatasetWrapper) – The dataset of samples annotated with spurious-attribute labels.

  • model (nn.Module) – The PyTorch model to be used.

  • labeled_valset_size (float) – The size of the labeled validation set as a fraction of the total labeled dataset size (default: 0.5).

  • num_iters (int) – The number of iterations for SSA.

  • tau_g_min (float) – The minimum value of the spurious correlation threshold tau_g (default: 0.95).

  • lr (float) – The learning rate for optimization (default: 1e-2).

  • weight_decay (float) – The weight decay for optimization (default: 5e-4).

  • batch_size (int) – The batch size for training (default: 64).

  • num_splits (int) – The number of splits for cross-validation (default: 3).

  • device (torch.device) – The device to be used for training (default: CPU).

  • verbose (bool) – Whether to print training progress (default: False).

infer_groups() → Dict[Tuple[int, int], List[int]]

Infer groups based on spurious attribute predictions.

Returns:

A dictionary mapping group labels to the indices of examples belonging to each group.

Return type:

Dict[Tuple[int, int], List[int]]

train_ssa(split_num: int) → Module

Train an SSA model for the specified split.

Parameters:

split_num (int) – The index of the split to train the SSA model for.

Returns:

The trained SSA model.

Return type:

nn.Module

label_split(split_num: int, best_ssa_model: Module) → array

Label the specified split using the best SSA model.

Parameters:
  • split_num (int) – The index of the split to label.

  • best_ssa_model (nn.Module) – The best SSA model obtained during training.

Returns:

An array of spurious labels for the split.

Return type:

np.array
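train_ssa() and label_split() together give one pseudo-labeling pass per split, repeated for num_splits splits. A minimal sketch of that control flow, assuming that the model used to label split k is trained without split k's examples, so no example is labeled by a model that saw it; this page does not say exactly which data train_ssa uses. train_fn and predict_fn are hypothetical stand-ins for SSA training and spurious-attribute prediction, and the (class label, inferred spurious label) keying is an assumption.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Tuple


def cross_fit_spurious_labels(num_examples: int, num_splits: int,
                              train_fn: Callable[[List[int]], object],
                              predict_fn: Callable[[object, int], int]) -> List[int]:
    """Pseudo-label each split with a model trained on the remaining splits."""
    # Round-robin assignment of example indices to splits.
    splits = [list(range(k, num_examples, num_splits)) for k in range(num_splits)]
    spurious = [0] * num_examples
    for k, held_out in enumerate(splits):
        train_idx = [i for j, split in enumerate(splits) if j != k for i in split]
        model = train_fn(train_idx)       # plays the role of train_ssa(k)
        for i in held_out:                # plays the role of label_split(k, model)
            spurious[i] = predict_fn(model, i)
    return spurious


def group_by_class_and_spurious(class_labels: Sequence[int],
                                spurious: Sequence[int]) -> Dict[Tuple[int, int], List[int]]:
    """Key each example by (class label, inferred spurious label)."""
    groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for i, key in enumerate(zip(class_labels, spurious)):
        groups[key].append(i)
    return dict(groups)


# Toy usage: the "model" is just the set of indices it was trained on.
spurious = cross_fit_spurious_labels(7, 3, train_fn=set, predict_fn=lambda model, i: i % 2)
groups = group_by_class_and_spurious([0, 0, 0, 1, 1, 1, 1], spurious)
```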

class SSATrainer(ssa: SSA, split_num: int)

Bases: object

__init__(ssa: SSA, split_num: int)

Initializes SSATrainer.

Parameters:
  • ssa (SSA) – The SSA object containing the data, model, etc. for training.

  • split_num (int) – The index of the split to train the SSA model for.

train() → None

Trains the model targeting the spurious attribute for the given split.

train_step(unlabeled_train_batch, labeled_train_batch)

Performs a single SSA training step on a given batch of unlabeled data and a batch of labeled data.

Parameters:
  • unlabeled_train_batch (torch.Tensor) – The batch of unlabeled training data.

  • labeled_train_batch (torch.Tensor) – The batch of labeled training data.
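A single step combines a supervised loss on the spurious-labeled batch with a loss on the unlabeled batch. A minimal sketch of the unlabeled half, assuming a confidence-thresholded pseudo-labeling rule in which only predictions whose top probability reaches a threshold contribute a cross-entropy term. A single fixed tau is a simplification: how SSA adapts tau_g above tau_g_min is not described on this page. The function name and the averaging over the full batch are assumptions.

```python
import math
from typing import List, Sequence, Tuple


def pseudo_label_loss(probs: Sequence[Sequence[float]], tau: float) -> Tuple[float, List[int]]:
    """Mean CE against argmax pseudo-labels, counting only confident rows."""
    total, used = 0.0, []
    for i, p in enumerate(probs):
        label = max(range(len(p)), key=p.__getitem__)  # pseudo-label = argmax
        if p[label] >= tau:                            # confident enough to trust
            total += -math.log(p[label])               # CE against the pseudo-label
            used.append(i)
    # Dividing by the full batch size (not just the confident rows) is one common choice.
    return total / len(probs), used


loss, used = pseudo_label_loss([[0.97, 0.03], [0.6, 0.4], [0.01, 0.99]], tau=0.95)
```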

validate()

Validates the SSA model on the validation set labeled with the spurious attribute.

Correct-N-Contrast Inference (CNC)

class CorrectNContrastInference(trainset: Dataset, model: Module, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), verbose: bool = False)

Bases: BaseGroupInference

Correct-n-Contrast Inference: https://proceedings.mlr.press/v162/zhang22z.html

__init__(trainset: Dataset, model: Module, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), verbose: bool = False)

Initializes the CorrectNContrastInference object.

Parameters:
  • trainset (Dataset) – The training dataset.

  • model (nn.Module) – The model for training.

  • batch_size (int) – The batch size for training.

  • optimizer (optim.Optimizer) – The optimizer for training.

  • num_epochs (int) – The number of epochs for training.

  • device (torch.device, optional) – The device to use for training. Defaults to CPU.

  • verbose (bool, optional) – Whether to print training progress. Defaults to False.

infer_groups() → Dict[Tuple[int, int], List[int]]

Performs Correct-n-Contrast inference to infer group partitions.

Returns:

The group partition based on Correct-n-Contrast inference.

Return type:

Dict[Tuple[int, int], List[int]]
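The inference stage of Correct-N-Contrast trains an ERM model and uses its predictions as a proxy for the spurious attribute. A minimal sketch of the final grouping step, assuming the trained model's logits are available and that groups are keyed as (true class, predicted class); that keying and the helper names are illustrative, not confirmed by this page.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def predictions_from_logits(logits: Sequence[Sequence[float]]) -> List[int]:
    """Argmax class prediction for each row of logits."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def cnc_groups(class_labels: Sequence[int],
               predicted_labels: Sequence[int]) -> Dict[Tuple[int, int], List[int]]:
    """Partition examples by (true class, ERM-predicted class)."""
    groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for idx, (y, y_hat) in enumerate(zip(class_labels, predicted_labels)):
        groups[(y, y_hat)].append(idx)
    return dict(groups)


preds = predictions_from_logits([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [0.2, 0.1]])
partition = cnc_groups([0, 0, 1, 1], preds)
```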

SPARE

class SpareInference(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, silhoutte_threshold: float = 0.9, high_sampling_power: int = 2, device: device = device(type='cpu'), verbose: bool = False)

Bases: Cluster

SPARE Inference: https://arxiv.org/abs/2305.18761

__init__(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, silhoutte_threshold: float = 0.9, high_sampling_power: int = 2, device: device = device(type='cpu'), verbose: bool = False)

Initializes SpareInference.

Parameters:
  • Z (torch.Tensor) – The output of the network.

  • class_labels (Optional[List[int]], optional) – Optional list of class labels for class-wise clustering. Defaults to None.

  • cluster_alg (ClusterAlg, optional) – The clustering algorithm to use. Defaults to ClusterAlg.KMEANS.

  • num_clusters (int, optional) – The number of clusters to create. Defaults to -1.

  • max_clusters (int, optional) – The maximum number of clusters to consider. Defaults to -1.

  • silhoutte_threshold (float, optional) – The silhouette threshold for determining the sampling powers. Defaults to 0.9.

  • high_sampling_power (int, optional) – The sampling power for the low-silhouette clusters. Defaults to 2.

  • device (torch.device, optional) – The device to run the clustering on. Defaults to torch.device("cpu").

  • verbose (bool, optional) – Whether to display progress and logging information. Defaults to False.

infer_groups() → Dict[int, List[int]]

Infers the group partition based on the clustering results.

Returns:

The group partition.

Return type:

Dict[int, List[int]]

Returns:

The sampling powers for each group.

Return type:

List[int]

Base Group Inference

class BaseGroupInference

Bases: ABC

Abstract base class for inferring group partitions.

__init__()

Initializes BaseGroupInference.

abstract infer_groups() → Dict[Tuple[int, int], List[int]]

Abstract method for inferring group partitions.

Returns:

Dictionary mapping group tuples to indices of examples belonging to each group.
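BaseGroupInference fixes the interface: a subclass implements infer_groups() and returns a dictionary from group tuples to example indices. A minimal, self-contained sketch of that pattern using a local stand-in for the base class (not the real spuco import) and a hypothetical ClassOnlyInference subclass that groups examples by class label alone.

```python
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Tuple


class BaseGroupInferenceSketch(ABC):
    """Local stand-in mirroring the documented interface."""

    @abstractmethod
    def infer_groups(self) -> Dict[Tuple[int, int], List[int]]:
        """Return a mapping from group tuples to example indices."""


class ClassOnlyInference(BaseGroupInferenceSketch):
    """Hypothetical subclass: one group per class, with the second slot fixed at 0."""

    def __init__(self, class_labels: List[int]):
        super().__init__()
        self.class_labels = class_labels

    def infer_groups(self) -> Dict[Tuple[int, int], List[int]]:
        groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
        for idx, label in enumerate(self.class_labels):
            groups[(label, 0)].append(idx)
        return dict(groups)


partition = ClassOnlyInference([1, 0, 1]).infer_groups()
```

Because infer_groups is abstract, the base class itself cannot be instantiated; only concrete subclasses can.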

process_cluster_partition(cluster_partition: Dict, class_index: int)

Processes a cluster partition:

  • Converts the cluster keys into (class, spurious) format.

  • Converts class-local indices from class-wise clustering into global (actual trainset) indices.

Parameters:
  • cluster_partition – Dictionary mapping cluster labels to indices of examples.

  • class_index – Index of the class being processed.

Returns:

Processed group partition mapping group tuples to indices of examples belonging to each group.
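The two conversions performed by process_cluster_partition() can be sketched as follows. This assumes, for illustration, that class-wise clustering was run on the subset of examples whose class label is class_index, so local position i in that subset maps back to global index class_members[i]. The function name and the explicit class_members argument are hypothetical; the documented method receives only cluster_partition and class_index.

```python
from typing import Dict, List, Sequence, Tuple


def to_group_partition(cluster_partition: Dict[int, List[int]], class_index: int,
                       class_members: Sequence[int]) -> Dict[Tuple[int, int], List[int]]:
    """Re-key clusters as (class, spurious) and map local indices to global ones."""
    return {
        (class_index, cluster_label): [class_members[i] for i in local_indices]
        for cluster_label, local_indices in cluster_partition.items()
    }


class_labels = [0, 1, 1, 0, 1]
members_of_class_1 = [i for i, y in enumerate(class_labels) if y == 1]  # [1, 2, 4]
partition = to_group_partition({0: [0, 2], 1: [1]}, class_index=1,
                               class_members=members_of_class_1)
# {(1, 0): [1, 4], (1, 1): [2]}
```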

static save_group_partition(group_partition: Dict[Tuple[int, int], List[int]], prefix: str)