Source code for probe_ably.core.models.abstract_model

from abc import ABC, abstractmethod
from typing import Dict

from torch import Tensor
from torch.nn import Module


class AbstractModel(Module, ABC):
    """Abstract base class for models built on ``torch.nn.Module``.

    Every subclass is recorded automatically in the class-level
    ``subclasses`` registry, keyed by its fully qualified name
    (``"<module>.<qualname>"``). Concrete subclasses must implement
    ``forward`` and ``get_complexity``.
    """

    # Registry shared by the whole hierarchy: qualified name -> subclass.
    subclasses = {}

    def __init_subclass__(cls, **kwargs):
        """Register ``cls`` in ``subclasses`` under its qualified name.

        Raises:
            Exception: If another class is already registered under the
                same ``"<module>.<qualname>"`` key.
        """
        super().__init_subclass__(**kwargs)
        name = cls.__module__ + "." + cls.__qualname__
        if name in cls.subclasses:
            # Report the conflicting registry key itself so the clash can
            # be located (previously the module name was printed twice).
            message = "Cannot register module %s as %s; name already in use" % (
                cls.__module__,
                name,
            )
            raise Exception(message)
        cls.subclasses[name] = cls

    def __init__(self, params):
        """Abstract class initialization method.

        Args:
            params: Mapping holding the initialization parameters. The keys
                read here are

                .. code-block:: json

                    {
                        'representation_size': Dimension of the representation,
                        'n_classes': Stored as ``self.n_classes``
                    }

                ``n_classes`` is presumably the number of target classes --
                confirm against the concrete models. Subclasses may read
                further keys from the same mapping.

        Raises:
            KeyError: If ``params`` is a dict and one of the keys above is
                missing.
        """
        super().__init__()
        self.representation_size = params["representation_size"]
        self.n_classes = params["n_classes"]

    @abstractmethod
    def forward(
        self, representation: Tensor, labels: Tensor, **kwargs
    ) -> Dict[str, Tensor]:
        """Abstract class forward method.

        Args:
            representation (Tensor): Representation tensors.
            labels (Tensor): Prediction labels.

        Returns:
            Dict[str, Tensor]: Dictionary of the form
            ``{'loss': loss, 'preds': preds}``.
        """
        ...

    @abstractmethod
    def get_complexity(self, **kwargs) -> Dict[str, float]:
        """Compute the complexity of the model.

        Returns:
            Dict[str, float]: Dictionary of the form
            ``{"complexity_measure1": value1, "complexity_measure2": value2}``.
        """
        ...