spuco.models
Built-in Models and ModelFactory.
ModelFactory
Supported Models:
MLP
LeNet
BERT
DistilBERT
ResNet18
ResNet50
- class Identity
Bases: Module
- __init__()
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
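The sketch below assumes, from the class name only, that Identity is a pass-through layer that returns its input unchanged (PyTorch also ships `nn.Identity` for this purpose). It is not the SpuCo source. It also illustrates the Note above: a forward hook fires when the module instance is called, but not when `forward` is called directly.

```python
import torch
from torch import nn


class Identity(nn.Module):
    """Pass-through layer (assumed behavior): returns its input unchanged."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


calls = []
layer = Identity()
# A forward hook fires only when the module is invoked via Module.__call__.
layer.register_forward_hook(lambda module, inputs, output: calls.append(tuple(output.shape)))

x = torch.randn(2, 3)
y_call = layer(x)            # goes through __call__, so the hook runs
y_direct = layer.forward(x)  # calls forward directly, so the hook is skipped
print(len(calls))  # 1
```

Both calls return the same tensor. Only the first one is recorded by the hook, which is why calling the instance is preferred.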
- class SupportedModels(value)
Bases: Enum
Enum listing all supported models.
- MLP = 'mlp'
- LeNet = 'lenet'
- BERT = 'bert'
- DistilBERT = 'distilbert'
- ResNet18 = 'resnet18'
- ResNet50 = 'resnet50'
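As an illustration, here is a stand-alone copy of the members listed above. Standard `Enum` value lookup turns a lowercase architecture string into a member. This sketch does not claim that `model_factory` parses `arch` this way.

```python
from enum import Enum


class SupportedModels(Enum):
    """Stand-alone copy of the enum members listed in the reference above."""
    MLP = "mlp"
    LeNet = "lenet"
    BERT = "bert"
    DistilBERT = "distilbert"
    ResNet18 = "resnet18"
    ResNet50 = "resnet50"


# Look up a member from its string value (e.g. an architecture name).
member = SupportedModels("resnet18")
print(member.name)  # ResNet18

# Unknown values raise ValueError in the standard Enum lookup.
try:
    SupportedModels("vit")
except ValueError as err:
    print("unsupported:", err)
```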
- model_factory(arch: str, input_shape: Tuple[int, int, int], num_classes: int, pretrained: bool = True)
Factory function to create a SpuCoModel based on the specified architecture.
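- Parameters:
arch (str) – Name of the architecture to build.
input_shape (Tuple[int, int, int]) – Shape of a single input.
num_classes (int) – Number of output classes.
pretrained (bool) – Whether to use pretrained weights. Defaults to True.
- Returns: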
A SpuCoModel instance.
- Return type:
SpuCoModel
- Raises:
NotImplementedError – If the specified architecture is not supported.
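To illustrate the factory contract documented above (a string `arch` selects a model, and unsupported names raise `NotImplementedError`), here is a minimal sketch. `toy_model_factory` is a hypothetical stand-in, not the SpuCo implementation. It only builds an MLP branch, and its hidden width of 64 is an arbitrary assumption.

```python
from typing import Tuple

import torch
from torch import nn


def toy_model_factory(arch: str, input_shape: Tuple[int, int, int], num_classes: int) -> nn.Module:
    """Hypothetical stand-in for model_factory; only an 'mlp' branch is sketched."""
    if arch == "mlp":
        # Flatten (C, H, W) inputs into a single feature vector.
        in_features = input_shape[0] * input_shape[1] * input_shape[2]
        hidden_dim = 64  # arbitrary width chosen for illustration
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )
    # Mirror the documented contract: unsupported names raise NotImplementedError.
    raise NotImplementedError(f"Architecture {arch!r} is not supported")


model = toy_model_factory("mlp", input_shape=(1, 28, 28), num_classes=10)
logits = model(torch.zeros(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])
```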
SpuCoModel
- class SpuCoModel(backbone: Module, representation_dim: int, num_classes: int)
Bases: Module
Wrapper module that lets methods relying on penultimate-layer embeddings access them easily via backbone.
- forward(x)
Forward pass of the SpuCoModel.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Output tensor.
- Return type:
torch.Tensor
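This sketch is a hypothetical re-creation of the wrapper idea described above, not the SpuCo source. It assumes that `backbone` maps inputs to `representation_dim`-dimensional embeddings and that a single linear head maps those to `num_classes` logits. The class name `SpuCoModelSketch` and the attribute name `classifier` are illustrative.

```python
import torch
from torch import nn


class SpuCoModelSketch(nn.Module):
    """Hypothetical re-creation of the wrapper: backbone -> embedding -> linear head."""

    def __init__(self, backbone: nn.Module, representation_dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone  # produces penultimate-layer embeddings
        # Assumed head: a single linear layer; the attribute name is hypothetical.
        self.classifier = nn.Linear(representation_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.backbone(x))


backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU())
model = SpuCoModelSketch(backbone, representation_dim=16, num_classes=2)

x = torch.randn(5, 3, 8, 8)
logits = model(x)               # shape (5, 2)
embeddings = model.backbone(x)  # shape (5, 16): what embedding-based methods consume
```

Keeping the backbone as a named submodule means embedding-based methods can call `model.backbone(x)` directly, without stripping the classifier head.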