### ADAPTED FROM https://github.com/rycolab/pareto-probing/blob/master/src/h02_learn/model/linear.py
from typing import Dict
import math
import numpy as np
import torch
from probe_ably.core.models import AbstractModel
from torch import Tensor, nn
class LinearModel(AbstractModel):
    """Linear probe with a nuclear-norm complexity penalty.

    The forward pass L2-normalises the input representation, applies dropout,
    projects it to class logits with a single ``nn.Linear`` layer and returns
    a loss made of the cross-entropy (converted to bits) plus ``alpha`` times
    the nuclear norm of the layer's weight matrix extended with its bias.
    """

    def __init__(self, params: Dict):
        """Initiate the Linear Model.

        Args:
            params (Dict): Parameters for initialization. Expected keys:

                * ``'representation_size'``: Dimension of the representation.
                * ``'dropout'``: Dropout probability of the module.
                * ``'n_classes'``: Number of classes for classification.
                * ``'alpha'``: Coefficient multiplying the nuclear-norm term
                  that is added to the loss.
        """
        super().__init__(params)
        # NOTE(review): ``representation_size`` and ``n_classes`` are not
        # assigned in this class; presumably ``AbstractModel.__init__`` sets
        # them from ``params`` -- confirm against the base class.
        self.dropout_p = params["dropout"]
        self.alpha = params["alpha"]

        self.linear = nn.Linear(self.representation_size, self.n_classes)
        self.dropout = nn.Dropout(self.dropout_p)
        self.criterion = nn.CrossEntropyLoss()

    def forward(
        self, representation: Tensor, labels: Tensor, eps: float = 1e-5, **kwargs
    ) -> Dict[str, Tensor]:
        """Forward method.

        Args:
            representation (Tensor): Representation tensors. Assumed to be of
                shape ``(batch, representation_size)`` -- TODO confirm with
                the caller.
            labels (Tensor): Prediction labels, passed as targets to the
                cross-entropy criterion.
            eps (float): Small constant added to the L2 norm so that an
                all-zero representation does not cause a division by zero.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Dict[str, Tensor]: Dictionary ``{'loss': loss, 'preds': preds}``
            where ``preds`` holds the index of the highest logit along
            dimension 1.
        """
        # Scale every representation vector to (approximately) unit L2 norm.
        norms = representation.norm(p=2, dim=-1, keepdim=True)
        representation = representation / (norms + eps)

        embeddings = self.dropout(representation)
        logits = self.linear(embeddings)
        preds = logits.argmax(dim=1)

        # CrossEntropyLoss uses the natural logarithm; dividing by ln(2)
        # expresses the value in bits before the complexity penalty is added.
        cross_entropy_bits = self.criterion(logits, labels) / math.log(2)
        loss = cross_entropy_bits + self.alpha * self.get_norm()
        return {"loss": loss, "preds": preds}

    def get_complexity(self, **kwargs) -> Dict[str, float]:
        """Computes the Nuclear Norm complexity.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Dict[str, float]: Complexity value as
            ``{'norm': nuclear norm score of model}``.
        """
        # Only a Python scalar is reported, so no autograd graph is needed.
        with torch.no_grad():
            return {"norm": self.get_norm().item()}

    def get_norm(self) -> Tensor:
        """Return the nuclear norm of the linear layer's parameters.

        The bias is appended to the weight matrix as one extra column, so the
        penalty covers both the weights and the bias.

        Returns:
            Tensor: Scalar tensor holding the nuclear norm. It stays attached
            to the autograd graph so it can be used as a loss term.
        """
        ext_matrix = torch.cat(
            [self.linear.weight, self.linear.bias.unsqueeze(-1)], dim=1
        )
        # NOTE: ``torch.norm`` is deprecated in recent PyTorch releases;
        # ``torch.linalg.matrix_norm(ext_matrix, ord="nuc")`` is the documented
        # replacement once the minimum supported PyTorch version allows it.
        return torch.norm(ext_matrix, p="nuc")
# def get_rank(self):
# ext_matrix = torch.cat([self.linear.weight, self.linear.bias.unsqueeze(-1)], dim=1)
# _, svd_matrix, _ = np.linalg.svd(ext_matrix.cpu().numpy())
# rank = np.sum(svd_matrix > 1e-3)
# return rank