spuco.models
Built-in Models and ModelFactory.
ModelFactory
Supported Models:
MLP
LeNet
BERT
DistilBERT
ResNet18
ResNet50
- class Identity
Bases: torch.nn.Module
- __init__()
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
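The difference described in the note can be checked directly. This is a minimal sketch that assumes Identity returns its input unchanged; it uses torch.nn.Identity as a stand-in rather than importing spuco.models.Identity. It registers a forward hook and shows that calling the module instance runs the hook, while calling forward() directly does not.

```python
import torch
from torch import nn

# Stand-in for spuco.models.Identity (assumed to return its input unchanged).
layer = nn.Identity()

calls = []


def record_hook(module, inputs, output):
    # Forward hooks receive the module, the positional inputs as a tuple, and the output.
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record_hook)

x = torch.randn(2, 3)
y_call = layer(x)             # __call__ runs forward() and then the registered hooks
y_forward = layer.forward(x)  # calling forward() directly skips the hooks

print(torch.equal(y_call, x))  # True: the identity returns its input
print(len(calls))              # 1: only layer(x) triggered the hook

handle.remove()  # detach the hook once it is no longer needed
```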
- class SupportedModels(value)
Bases: enum.Enum
Enum listing all supported models.
- MLP = 'mlp'
- LeNet = 'lenet'
- BERT = 'bert'
- DistilBERT = 'distilbert'
- ResNet18 = 'resnet18'
- ResNet50 = 'resnet50'
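Because each member maps to a lowercase string, a member can be recovered from a string such as a command-line argument or config entry. The sketch below reproduces the enum locally for illustration only (it does not import spuco); whether model_factory expects these exact strings for its arch argument is an assumption here.

```python
from enum import Enum


# Local reproduction of the SupportedModels enum documented above, for illustration.
class SupportedModels(Enum):
    MLP = "mlp"
    LeNet = "lenet"
    BERT = "bert"
    DistilBERT = "distilbert"
    ResNet18 = "resnet18"
    ResNet50 = "resnet50"


# Look up a member by its string value.
arch = SupportedModels("resnet18")
print(arch)        # SupportedModels.ResNet18
print(arch.value)  # resnet18

# List every valid architecture string.
print([m.value for m in SupportedModels])
# ['mlp', 'lenet', 'bert', 'distilbert', 'resnet18', 'resnet50']

# An unknown value raises ValueError.
try:
    SupportedModels("vit")
except ValueError as err:
    print("unsupported:", err)
```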
- model_factory(arch: str, input_shape: Tuple[int, int, int], num_classes: int, pretrained: bool = True)
Factory function to create a SpuCoModel based on the specified architecture.
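- Parameters:
arch (str) – Name of the architecture to build; must be one of the supported models.
input_shape (Tuple[int, int, int]) – Shape of a single input example.
num_classes (int) – Number of output classes.
pretrained (bool, optional) – Whether to use pretrained weights. Defaults to True.
- Returns:
A SpuCoModel instance.
- Return type:
SpuCoModel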
- Raises:
NotImplementedError – If the specified architecture is not supported.
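The factory pattern behind model_factory can be sketched as a registry of builders plus a check that raises NotImplementedError for unknown names. This is a hypothetical sketch, not the spuco implementation: build_mlp, BUILDERS and sketch_model_factory are illustrative names, input_shape is assumed to be (channels, height, width), and an nn.Sequential head stands in for SpuCoModel.

```python
from typing import Callable, Dict, Tuple

import torch
from torch import nn


def build_mlp(input_shape: Tuple[int, int, int], pretrained: bool) -> Tuple[nn.Module, int]:
    # Hypothetical builder: flatten the input and project it to a 64-dim representation.
    # 'pretrained' is ignored because this toy MLP has no pretrained weights.
    in_features = input_shape[0] * input_shape[1] * input_shape[2]
    backbone = nn.Sequential(nn.Flatten(), nn.Linear(in_features, 64), nn.ReLU())
    return backbone, 64


# Hypothetical registry mapping architecture names to builders.
BUILDERS: Dict[str, Callable[[Tuple[int, int, int], bool], Tuple[nn.Module, int]]] = {
    "mlp": build_mlp,
}


def sketch_model_factory(
    arch: str,
    input_shape: Tuple[int, int, int],
    num_classes: int,
    pretrained: bool = True,
) -> nn.Module:
    if arch not in BUILDERS:
        # Mirrors the documented behaviour for unsupported architectures.
        raise NotImplementedError(f"Model {arch} not supported")
    backbone, representation_dim = BUILDERS[arch](input_shape, pretrained)
    # Stand-in for SpuCoModel(backbone, representation_dim, num_classes).
    return nn.Sequential(backbone, nn.Linear(representation_dim, num_classes))


model = sketch_model_factory("mlp", input_shape=(1, 28, 28), num_classes=10)
logits = model(torch.zeros(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])
```

Keeping the builders in a dictionary means adding an architecture only requires registering one more builder, and the unsupported-architecture check stays in a single place.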
SpuCoModel
- class SpuCoModel(backbone: Module, representation_dim: int, num_classes: int)
Bases: torch.nn.Module
Wrapper module that lets methods relying on penultimate-layer embeddings access those embeddings easily via the backbone.
- forward(x)
Forward pass of the SpuCoModel.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Output tensor.
- Return type:
torch.Tensor
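The wrapper idea can be illustrated with a minimal sketch. It assumes SpuCoModel applies the backbone and then a linear classifier mapping representation_dim features to num_classes logits. WrapperSketch and its classifier attribute are hypothetical names, not taken from spuco. The backbone attribute matches the constructor argument documented above.

```python
import torch
from torch import nn


class WrapperSketch(nn.Module):
    """Hypothetical stand-in for SpuCoModel: backbone followed by a linear classifier."""

    def __init__(self, backbone: nn.Module, representation_dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone
        # 'classifier' is an assumed attribute name used only for this sketch.
        self.classifier = nn.Linear(representation_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Logits are computed from the backbone's penultimate-layer representation.
        return self.classifier(self.backbone(x))


backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
model = WrapperSketch(backbone, representation_dim=32, num_classes=2)

x = torch.randn(5, 3, 8, 8)
logits = model(x)               # shape (5, 2)
embeddings = model.backbone(x)  # penultimate-layer embeddings, shape (5, 32)
```

Because the backbone is a separate submodule, code that needs embeddings can call it directly instead of registering hooks or slicing the network.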