spuco.models

Built-in Models and ModelFactory.

ModelFactory

Supported Models:

  • MLP

  • LeNet

  • BERT

  • DistilBERT

  • ResNet18

  • ResNet50

class Identity

Bases: Module

__init__()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
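The docstrings above are inherited from torch.nn.Module and do not describe Identity itself. Going by its name, Identity presumably returns its input unchanged (PyTorch ships an equivalent, torch.nn.Identity). A common use for such a module is replacing a network's final classification layer so that the network outputs penultimate-layer features. The sketch below assumes that pass-through behavior. The toy network and its layer sizes are made up for illustration.

```python
import torch
from torch import nn


class Identity(nn.Module):
    """Pass-through module (assumed behavior): forward returns its input as-is."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


# Hypothetical toy network: a feature extractor followed by a 10-way classifier.
backbone = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 8 * 8, 32),
    nn.ReLU(),
    nn.Linear(32, 10),  # classification head
)

# Swap the head for Identity so the network returns 32-dim features instead of logits.
backbone[3] = Identity()

x = torch.randn(4, 3, 8, 8)
features = backbone(x)  # call the module instance (not .forward) so hooks run
print(features.shape)  # torch.Size([4, 32])
```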

class SupportedModels(value)

Bases: Enum

Enum listing all supported models.

MLP = 'mlp'
LeNet = 'lenet'
BERT = 'bert'
DistilBERT = 'distilbert'
ResNet18 = 'resnet18'
ResNet50 = 'resnet50'
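Each member's value is the lowercase architecture string, so a name such as 'resnet18' can be converted to a member with the standard Enum value lookup. The sketch below re-declares the enum with the values listed above so it runs without SpuCo installed. Whether model_factory performs this exact lookup internally is an assumption.

```python
from enum import Enum


class SupportedModels(Enum):
    """Local copy of spuco.models.SupportedModels, re-declared for illustration."""

    MLP = "mlp"
    LeNet = "lenet"
    BERT = "bert"
    DistilBERT = "distilbert"
    ResNet18 = "resnet18"
    ResNet50 = "resnet50"


# Look up a member by its string value.
arch = SupportedModels("resnet18")
print(arch)  # SupportedModels.ResNet18

# List every accepted architecture string.
names = [m.value for m in SupportedModels]
print(names)  # ['mlp', 'lenet', 'bert', 'distilbert', 'resnet18', 'resnet50']

# Unknown names raise ValueError from the Enum lookup.
try:
    SupportedModels("vgg16")
except ValueError as err:
    print(err)
```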

model_factory(arch: str, input_shape: Tuple[int, int, int], num_classes: int, pretrained: bool = True)

Factory function to create a SpuCoModel based on the specified architecture.

Parameters:
  • arch (str) – The architecture name, given as one of the SupportedModels values (for example, 'resnet18').

  • input_shape (Tuple[int, int, int]) – The shape of the input data in the format (channels, height, width).

  • num_classes (int) – The number of output classes.

  • pretrained (bool) – Whether to load pretrained weights. Default is True.

Returns:

A SpuCoModel instance.

Return type:

SpuCoModel

Raises:

NotImplementedError – If the specified architecture is not supported.
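A factory of this kind typically dispatches on arch, builds a backbone, and attaches a classification head. The sketch below is hypothetical. toy_model_factory and the hidden width of 64 are made-up names and values, only an MLP branch is shown, the pretrained flag is omitted, and a plain nn.Sequential stands in for the SpuCoModel the real factory returns. It illustrates how an input_shape in (channels, height, width) format determines the flattened input size of an MLP, and how an unsupported arch raises NotImplementedError.

```python
from typing import Tuple

import torch
from torch import nn


def toy_model_factory(arch: str, input_shape: Tuple[int, int, int], num_classes: int) -> nn.Module:
    """Hypothetical factory: only 'mlp' is handled; anything else raises NotImplementedError."""
    if arch == "mlp":
        channels, height, width = input_shape
        hidden_dim = 64  # assumed representation (penultimate-layer) size
        backbone = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * height * width, hidden_dim),
            nn.ReLU(),
        )
        # The real factory returns a SpuCoModel; a plain backbone + linear head stands in here.
        return nn.Sequential(backbone, nn.Linear(hidden_dim, num_classes))
    raise NotImplementedError(f"Architecture '{arch}' is not supported")


model = toy_model_factory("mlp", input_shape=(1, 28, 28), num_classes=10)
out = model(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```

With SpuCo installed, the documented signature suggests a call such as model_factory("resnet18", input_shape=(3, 224, 224), num_classes=2, pretrained=True) after from spuco.models import model_factory.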

SpuCoModel

class SpuCoModel(backbone: Module, representation_dim: int, num_classes: int)

Bases: Module

Wrapper module that keeps the network's feature extractor in a separate backbone attribute, so that methods relying on penultimate-layer embeddings can obtain them directly through backbone.

__init__(backbone: Module, representation_dim: int, num_classes: int)

Initializes a SpuCoModel.

Parameters:
  • backbone (torch.nn.Module) – The backbone network.

  • representation_dim (int) – The dimensionality of the penultimate layer embeddings.

  • num_classes (int) – The number of output classes.

forward(x)

Forward pass of the SpuCoModel.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Output tensor.

Return type:

torch.Tensor
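The point of the wrapper is that penultimate-layer embeddings come from model.backbone(x) while model(x) returns the class outputs. The minimal sketch below assumes the classification head is a single nn.Linear(representation_dim, num_classes) stored in an attribute named classifier. That attribute name, the single-layer head, and the toy backbone are assumptions, not taken from the SpuCo source.

```python
import torch
from torch import nn


class SpuCoModelSketch(nn.Module):
    """Assumed structure: backbone -> embeddings -> linear classifier -> logits."""

    def __init__(self, backbone: nn.Module, representation_dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone
        # Assumption: a single linear head maps embeddings to class scores.
        self.classifier = nn.Linear(representation_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.backbone(x))


# Toy backbone producing 16-dim embeddings from 20-dim inputs.
backbone = nn.Sequential(nn.Linear(20, 16), nn.ReLU())
model = SpuCoModelSketch(backbone, representation_dim=16, num_classes=3)
model.eval()

x = torch.randn(5, 20)
with torch.no_grad():
    embeddings = model.backbone(x)  # penultimate-layer embeddings, shape (5, 16)
    logits = model(x)               # class scores, shape (5, 3)

    # Feeding the embeddings through the head reproduces the full forward pass.
    print(torch.allclose(model.classifier(embeddings), logits))  # True
```

Keeping the head separate from the backbone means methods that retrain only the last layer can reuse the frozen embeddings without re-running or modifying the feature extractor.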