class liqfit.modeling.ClassificationHead

(in_features: int, out_features: int, pooler: nn.Module, loss_func: nn.Module, bias: bool = True, temperature: float = 1.0, eps: float = 1e-5)


  • in_features (int): Number of input features.

  • out_features (int): Number of output features.

  • pooler (nn.Module): Pooling function to use when the input is not multi-target.

  • loss_func (nn.Module): Loss function that will be called if labels are passed.

  • bias (bool): Whether the nn.Linear layer uses a bias term. (Defaults to True)

  • temperature (float): Temperature by which the linear layer's output is divided to calibrate it. (Defaults to 1.0)

  • eps (float): Epsilon added to the temperature for numerical stability. (Defaults to 1e-5)
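As a rough illustration of how temperature and eps interact, the plain-Python sketch below assumes the head divides its logits by (temperature + eps) before a softmax is applied. `temperature_scaled_softmax` is a hypothetical helper written for this page, not part of liqfit. A larger temperature flattens the class distribution without changing which class scores highest, and eps keeps the division defined even when temperature is 0.

```python
import math


def temperature_scaled_softmax(logits, temperature=1.0, eps=1e-5):
    """Hypothetical helper: divide logits by (temperature + eps), then softmax.

    Assumes the calibration described for ClassificationHead is
    logits / (temperature + eps); this is not liqfit's code.
    """
    scale = temperature + eps  # eps avoids division by zero when temperature == 0
    scaled = [z / scale for z in logits]
    # Subtract the max before exponentiating so large logits do not overflow.
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
sharp = temperature_scaled_softmax(logits, temperature=1.0)
smooth = temperature_scaled_softmax(logits, temperature=2.0)
print([round(p, 3) for p in sharp])
print([round(p, 3) for p in smooth])
```

Both calls rank the classes identically; only the confidence assigned to the top class changes.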

Using ClassificationHead

Use ClassificationHead when you want more flexibility in choosing your loss function and pooling method.

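import torch
from liqfit.losses import CrossEntropyLoss
from liqfit.modeling.heads import ClassificationHead
from liqfit.modeling.pooling import GlobalAvgPooling

# Loss function, called only when labels are passed to the head
loss = CrossEntropyLoss(multi_target=False)
# Pools the token embeddings into one vector per sequence
pooler = GlobalAvgPooling()
# Maps 512-dimensional pooled embeddings to 20 output features
head = ClassificationHead(512, 20, loss_func=loss, pooler=pooler)
# A batch of 1 sequence with 10 tokens and 512-dimensional embeddings
embeddings = torch.randn((1, 10, 512))
# No labels are passed, so the loss function is not called
output = head(embeddings)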
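The data flow in the example above can be sketched without liqfit or torch, under a few stated assumptions: GlobalAvgPooling is taken to be a mean over the token dimension, the head is a single linear layer whose output is divided by (temperature + eps), and plain cross-entropy stands in for loss_func. All function names here (`mean_pool`, `linear`, `cross_entropy`, `sketch_forward`) are hypothetical, and this is not liqfit's implementation; it only shows how a (1, 10, 512) input becomes (1, 20) logits and when a loss is computed.

```python
import math
import random


def mean_pool(embeddings):
    """Assumed pooling: average a (batch, seq_len, hidden) nested list over seq_len."""
    pooled = []
    for sequence in embeddings:
        seq_len, hidden = len(sequence), len(sequence[0])
        pooled.append([sum(tok[h] for tok in sequence) / seq_len for h in range(hidden)])
    return pooled


def linear(x, weight, bias):
    """Apply y = W x + b to each row of x; weight has shape (out, in)."""
    return [
        [sum(w_i * x_i for w_i, x_i in zip(row, vec)) + b for row, b in zip(weight, bias)]
        for vec in x
    ]


def cross_entropy(logits, labels):
    """Mean negative log-likelihood of integer class labels (stand-in for loss_func)."""
    losses = []
    for row, y in zip(logits, labels):
        m = max(row)
        log_norm = m + math.log(sum(math.exp(z - m) for z in row))
        losses.append(log_norm - row[y])
    return sum(losses) / len(losses)


def sketch_forward(embeddings, weight, bias, labels=None, temperature=1.0, eps=1e-5):
    """Hypothetical head forward pass: pool, project, calibrate, optional loss."""
    pooled = mean_pool(embeddings)                # (batch, hidden)
    logits = linear(pooled, weight, bias)         # (batch, out_features)
    logits = [[z / (temperature + eps) for z in row] for row in logits]
    loss = cross_entropy(logits, labels) if labels is not None else None
    return logits, loss


random.seed(0)
in_features, out_features = 512, 20
# Same shapes as the example: 1 sequence, 10 tokens, 512-dimensional embeddings.
embeddings = [[[random.gauss(0, 1) for _ in range(in_features)] for _ in range(10)]]
weight = [[random.gauss(0, 0.02) for _ in range(in_features)] for _ in range(out_features)]
bias = [0.0] * out_features

logits, loss = sketch_forward(embeddings, weight, bias)
print(len(logits), len(logits[0]), loss)  # no labels, so loss is None
logits, loss = sketch_forward(embeddings, weight, bias, labels=[3])
print(loss)
```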
