spuco.group_inference
Group Inference Methods.
Cluster
- class Cluster(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, device: device = device(type='cpu'), verbose: bool = False)
Bases:
BaseGroupInference
Clustering-based Group Inference
- __init__(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, device: device = device(type='cpu'), verbose: bool = False)
Initializes the Cluster object.
- Parameters:
Z (torch.Tensor) – The input tensor for clustering.
class_labels (Optional[List[int]], optional) – Optional list of class labels for class-wise clustering. Defaults to None.
cluster_alg (ClusterAlg, optional) – The clustering algorithm to use. Defaults to ClusterAlg.KMEANS.
num_clusters (int, optional) – The number of clusters to create. Defaults to -1.
max_clusters (int, optional) – The maximum number of clusters to consider. Defaults to -1.
device (torch.device, optional) – The device to run the clustering on. Defaults to torch.device("cpu").
verbose (bool, optional) – Whether to display progress and logging information. Defaults to False.
- silhouette(Z)
Uses the silhouette score to determine the optimal number of clusters and perform clustering based on self.cluster_alg.
- Parameters:
Z (torch.Tensor) – The input data for clustering.
- Returns:
The cluster partition based on the optimal number of clusters.
- Return type:
List[int]
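The selection rule can be illustrated without any clustering library. The sketch below is not the spuco implementation: it scores two hand-written candidate partitions of 1-D points with the mean silhouette score and keeps the number of clusters that scores highest. The function name `mean_silhouette`, the toy data, and the use of absolute difference as the distance are all assumptions made for illustration.

```python
from typing import Dict, List, Sequence


def mean_silhouette(points: Sequence[float], labels: Sequence[int]) -> float:
    """Mean silhouette score for 1-D points, using |x - y| as the distance."""
    clusters: Dict[int, List[int]] = {}
    for i, label in enumerate(labels):
        clusters.setdefault(label, []).append(i)

    scores = []
    for i, label in enumerate(labels):
        own = [j for j in clusters[label] if j != i]
        if not own:
            scores.append(0.0)  # common convention: singleton clusters score 0
            continue
        # a: mean distance to the other members of the same cluster
        a = sum(abs(points[i] - points[j]) for j in own) / len(own)
        # b: smallest mean distance to the members of any other cluster
        b = min(
            sum(abs(points[i] - points[j]) for j in members) / len(members)
            for other, members in clusters.items()
            if other != label
        )
        scores.append((b - a) / max(a, b))
    return sum(scores) / len(scores)


# Toy data: two well-separated blobs.
points = [0.0, 0.1, 0.2, 5.0, 5.1, 5.2]

# Hand-written candidate labelings keyed by number of clusters (assumed,
# standing in for the output of a real clustering algorithm).
candidates = {
    2: [0, 0, 0, 1, 1, 1],
    3: [0, 0, 1, 2, 2, 2],
}

best_k = max(candidates, key=lambda k: mean_silhouette(points, candidates[k]))
print(best_k)  # 2
```

In practice the candidate labelings would come from running the chosen clustering algorithm once per candidate cluster count.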
- kmeans(Z, num_clusters: int = -1)
Performs K-means clustering on the input data.
- Parameters:
Z (torch.Tensor) – The input data for clustering.
num_clusters (int, optional) – The number of clusters to create. If not specified, the value from the object will be used.
- Returns:
The cluster labels and partition based on the K-means clustering.
- Return type:
Tuple[np.ndarray, List[List[int]]]
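The return type pairs a flat label array with a partition. A minimal sketch of how the two representations can relate, assuming (the text above does not confirm this) that the partition lists the sample indices of each cluster in label order. `labels_to_partition` is a hypothetical helper, not part of spuco.

```python
from typing import List, Sequence


def labels_to_partition(labels: Sequence[int], num_clusters: int) -> List[List[int]]:
    """Group sample indices by their cluster label (hypothetical helper)."""
    partition: List[List[int]] = [[] for _ in range(num_clusters)]
    for index, label in enumerate(labels):
        partition[label].append(index)
    return partition


labels = [0, 1, 0, 2, 1]  # one cluster label per sample
partition = labels_to_partition(labels, num_clusters=3)
print(partition)  # [[0, 2], [1, 4], [3]]
```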
- kmedoids(Z, similiarity_matrix: Tensor, num_clusters=-1)
Performs K-medoids clustering on the input data.
- Parameters:
Z (torch.Tensor) – The input data for clustering.
similiarity_matrix (torch.Tensor) – The similarity matrix for pairwise similarities between data points.
num_clusters (int, optional) – The number of clusters to create. If not specified, the value from the object will be used.
- Returns:
The cluster labels and partition based on the K-medoids clustering.
- Return type:
Tuple[np.ndarray, List[List[int]]]
Just Train Twice (JTT) Inference
- class JTTInference(predictions: List[int], class_labels: List[int])
Bases:
BaseGroupInference
Just Train Twice Inference: https://arxiv.org/abs/2107.09044
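The idea in the JTT paper is to separate the examples a first-pass model misclassifies from those it gets right. A minimal sketch of that split follows. The function name `split_by_correctness` and the key layout, `(0, 0)` for correct and `(0, 1)` for misclassified, are assumptions chosen to match the `Dict[Tuple[int, int], List[int]]` shape used elsewhere in this module, and may differ from what `JTTInference` actually returns.

```python
from typing import Dict, List, Tuple


def split_by_correctness(
    predictions: List[int], class_labels: List[int]
) -> Dict[Tuple[int, int], List[int]]:
    """Partition example indices by whether the prediction matches the label."""
    groups: Dict[Tuple[int, int], List[int]] = {(0, 0): [], (0, 1): []}
    for index, (pred, label) in enumerate(zip(predictions, class_labels)):
        key = (0, 1) if pred != label else (0, 0)  # assumed key layout
        groups[key].append(index)
    return groups


groups = split_by_correctness(predictions=[0, 1, 1, 0], class_labels=[0, 1, 0, 0])
print(groups)  # {(0, 0): [0, 1, 3], (0, 1): [2]}
```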
Environment Inference for Invariant Learning (EIIL)
- class EIIL(logits: Tensor, class_labels: List[int], num_steps: int, lr: float, device: device = device(type='cpu'), verbose: bool = False)
Bases:
BaseGroupInference
Environment Inference for Invariant Learning: https://arxiv.org/abs/2010.07249
- __init__(logits: Tensor, class_labels: List[int], num_steps: int, lr: float, device: device = device(type='cpu'), verbose: bool = False)
Initializes the EIIL object.
- Parameters:
logits (torch.Tensor) – The logits output of the model.
class_labels (List[int]) – The class labels for each sample.
num_steps (int) – The number of steps for training the soft environment assignment.
lr (float) – The learning rate for training.
device (torch.device, optional) – The device to use for training. Defaults to CPU.
verbose (bool, optional) – Whether to print training progress. Defaults to False.
Spread Spurious Attribute (SSA)
- class SSA(spurious_unlabeled_dataset: BaseSpuCoCompatibleDataset, spurious_labeled_dataset: SpuriousTargetDatasetWrapper, model: Module, num_iters: int, labeled_valset_size: float = 0.5, tau_g_min: float = 0.95, lr: float = 0.01, weight_decay: float = 0.0005, batch_size: int = 64, num_splits: int = 3, device: device = device(type='cpu'), verbose: bool = False)
Bases:
BaseGroupInference
Spread Spurious Attribute: https://arxiv.org/abs/2204.02070
- __init__(spurious_unlabeled_dataset: BaseSpuCoCompatibleDataset, spurious_labeled_dataset: SpuriousTargetDatasetWrapper, model: Module, num_iters: int, labeled_valset_size: float = 0.5, tau_g_min: float = 0.95, lr: float = 0.01, weight_decay: float = 0.0005, batch_size: int = 64, num_splits: int = 3, device: device = device(type='cpu'), verbose: bool = False)
Initializes SSA.
- Parameters:
spurious_unlabeled_dataset (BaseSpuCoCompatibleDataset) – The dataset of samples without spurious attribute labels.
spurious_labeled_dataset (SpuriousTargetDatasetWrapper) – The dataset of samples with spurious attribute labels.
model (nn.Module) – The PyTorch model to be used.
labeled_valset_size (float) – The size of the labeled validation set as a fraction of the total labeled dataset size.
num_iters (int) – The number of iterations for SSA.
tau_g_min (float) – The minimum value of the spurious correlation threshold tau_g.
lr (float) – The learning rate for optimization (default: 1e-2).
weight_decay (float) – The weight decay for optimization (default: 5e-4).
batch_size (int) – The batch size for training (default: 64).
num_splits (int) – The number of splits for cross-validation (default: 3).
device (torch.device) – The device to be used for training (default: CPU).
verbose (bool) – Whether to print training progress (default: False).
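A threshold such as tau_g_min typically gates which pseudo-labels are trusted. The sketch below shows plain confidence thresholding only: a predicted spurious attribute is kept when its softmax probability reaches the threshold and discarded otherwise. How SSA derives per-group thresholds from tau_g_min is not shown, and `softmax` and `pseudo_label` are hypothetical helpers rather than spuco functions.

```python
import math
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    """Convert logits to probabilities."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label(logits: List[float], threshold: float = 0.95) -> Optional[int]:
    """Return the predicted attribute if confident enough, else None."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=lambda k: probs[k])
    return best if probs[best] >= threshold else None


confident = pseudo_label([4.0, 0.0])  # top probability is about 0.98
uncertain = pseudo_label([1.0, 0.0])  # top probability is about 0.73
print(confident, uncertain)  # 0 None
```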
- infer_groups() → Dict[Tuple[int, int], List[int]]
Infer groups based on spurious attribute predictions.
- train_ssa(split_num: int) → Module
Train an SSA model for the specified split.
- Parameters:
split_num (int) – The index of the split to train the SSA model for.
- Returns:
The trained SSA model.
- Return type:
nn.Module
- label_split(split_num: int, best_ssa_model: Module) → array
Label the specified split using the best SSA model.
- Parameters:
split_num (int) – The index of the split to label.
best_ssa_model (nn.Module) – The best SSA model obtained during training.
- Returns:
An array of spurious labels for the split.
- Return type:
np.array
- class SSATrainer(ssa: SSA, split_num: int)
Bases:
object
- __init__(ssa: SSA, split_num: int)
Initializes SSATrainer.
- Parameters:
ssa (SSA) – The SSA object containing the data, model etc. for training.
split_num (int) – The index of the split to train the SSA model for.
- train_step(unlabeled_train_batch, labeled_train_batch)
Trains a single step of SSA for a given batch of unlabeled and labeled data.
- Parameters:
unlabeled_train_batch (torch.Tensor) – The batch of unlabeled training data.
labeled_train_batch (torch.Tensor) – The batch of labeled training data.
- validate()
Validates the SSA model on the validation set labeled with spurious attributes.
Correct-N-Contrast Inference (CNC)
- class CorrectNContrastInference(trainset: Dataset, model: Module, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), verbose: bool = False)
Bases:
BaseGroupInference
Correct-n-Contrast Inference: https://proceedings.mlr.press/v162/zhang22z.html
- __init__(trainset: Dataset, model: Module, batch_size: int, optimizer: Optimizer, num_epochs: int, device: device = device(type='cpu'), verbose: bool = False)
Initializes the CorrectNContrastInference object.
- Parameters:
trainset (Dataset) – The training dataset.
model (nn.Module) – The model for training.
batch_size (int) – The batch size for training.
optimizer (optim.Optimizer) – The optimizer for training.
num_epochs (int) – The number of epochs for training.
device (torch.device, optional) – The device to use for training. Defaults to CPU.
verbose (bool, optional) – Whether to print training progress. Defaults to False.
SPARE
- class SpareInference(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, silhoutte_threshold: float = 0.9, high_sampling_power: int = 2, device: device = device(type='cpu'), verbose: bool = False)
Bases:
Cluster
SPARE Inference: https://arxiv.org/abs/2305.18761
- __init__(Z: Tensor, class_labels: List[int] | None = None, cluster_alg: ClusterAlg = ClusterAlg.KMEANS, num_clusters: int = -1, max_clusters: int = -1, silhoutte_threshold: float = 0.9, high_sampling_power: int = 2, device: device = device(type='cpu'), verbose: bool = False)
Initializes SpareInference.
- Parameters:
Z (torch.Tensor) – The output of the network.
class_labels (Optional[List[int]], optional) – Optional list of class labels for class-wise clustering. Defaults to None.
cluster_alg (ClusterAlg, optional) – The clustering algorithm to use. Defaults to ClusterAlg.KMEANS.
num_clusters (int, optional) – The number of clusters to create. Defaults to -1.
max_clusters (int, optional) – The maximum number of clusters to consider. Defaults to -1.
silhoutte_threshold (float, optional) – The silhouette threshold for determining the sampling powers. Defaults to 0.9.
high_sampling_power (int, optional) – The sampling power for the low-silhouette clusters. Defaults to 2.
device (torch.device, optional) – The device to run the clustering on. Defaults to torch.device("cpu").
verbose (bool, optional) – Whether to display progress and logging information. Defaults to False.
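One plausible reading of how the threshold and the sampling power interact is sketched below. It rests on assumptions: clusters are weighted inversely to their size raised to a power, the power is `high_sampling_power` when the silhouette score falls below the threshold, and it is 1 otherwise. The inverse-size weighting and the fallback power of 1 are not stated in the parameter list above, and `sampling_weights` is a hypothetical function (its `silhouette_threshold` argument uses the conventional spelling, unlike the library parameter).

```python
from typing import List


def sampling_weights(
    partition: List[List[int]],
    silhouette: float,
    silhouette_threshold: float = 0.9,
    high_sampling_power: int = 2,
) -> List[float]:
    """Per-sample sampling probabilities that favor small clusters."""
    # Assumption: a low silhouette score selects the higher power, else power 1.
    power = high_sampling_power if silhouette < silhouette_threshold else 1
    num_samples = sum(len(cluster) for cluster in partition)
    weights = [0.0] * num_samples
    for cluster in partition:
        for index in cluster:
            weights[index] = 1.0 / len(cluster) ** power
    total = sum(weights)
    return [w / total for w in weights]  # normalize to sum to 1


partition = [[0, 1, 2], [3]]  # one large cluster, one singleton
low = sampling_weights(partition, silhouette=0.5)    # power 2 is used
high = sampling_weights(partition, silhouette=0.95)  # power 1 is used
print(low[3], high[3])
```

Under these assumptions the singleton is drawn more often when the higher power applies, which matches the intent of upweighting small clusters more aggressively.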
Base Group Inference
- class BaseGroupInference
Bases:
ABC
BaseGroupInference abstract class for inferring group partitions.
- __init__()
Initializes BaseGroupInference.
- abstract infer_groups() → Dict[Tuple[int, int], List[int]]
Abstract method for inferring group partitions.
- Returns:
Dictionary mapping group tuples to indices of examples belonging to each group.
- process_cluster_partition(cluster_partition: Dict, class_index: int)
Processes a cluster partition by converting keys from clusters into (class, spurious) format and converting class indices from class-wise clustering into global (actual trainset) indices.
- Parameters:
cluster_partition – Dictionary mapping cluster labels to indices of examples.
class_index – Index of the class being processed.
- Returns:
Processed group partition mapping group tuples to indices of examples belonging to each group.
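The two conversions described for process_cluster_partition can be sketched as a standalone function. This is an illustration, not the library method: `to_group_partition` and its `class_sample_indices` argument (the global indices of the samples in the class being processed) are hypothetical, since the documented method takes only the partition and the class index.

```python
from typing import Dict, List, Tuple


def to_group_partition(
    cluster_partition: Dict[int, List[int]],
    class_index: int,
    class_sample_indices: List[int],
) -> Dict[Tuple[int, int], List[int]]:
    """Re-key clusters as (class, spurious) and map local to global indices."""
    group_partition: Dict[Tuple[int, int], List[int]] = {}
    for cluster_label, local_indices in cluster_partition.items():
        # Local index i refers to the i-th sample of this class.
        group_partition[(class_index, cluster_label)] = [
            class_sample_indices[i] for i in local_indices
        ]
    return group_partition


# Class 1 owns trainset samples 4, 7 and 9; clustering saw them as 0, 1, 2.
result = to_group_partition(
    cluster_partition={0: [0, 2], 1: [1]},
    class_index=1,
    class_sample_indices=[4, 7, 9],
)
print(result)  # {(1, 0): [4, 9], (1, 1): [7]}
```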