from liqfit.modeling.heads import LiqFitHead
from liqfit.modeling.heads import HeadOutput
from torch import nn
class MyOwnDownstreamHead(LiqFitHead):
    """Example custom downstream head: a single linear projection.

    Replace the bodies of ``compute_loss`` and ``forward`` with your own
    logic; this version is a minimal, runnable starting point.

    Args:
        in_features: Size of each incoming embedding vector.
        out_features: Number of output logits (e.g. number of classes).
    """

    def __init__(self, in_features, out_features):
        # NOTE(review): assumes LiqFitHead is an nn.Module subclass; the base
        # initializer must run before any submodule is assigned to self.
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Default loss; swap in any loss that fits your task.
        self.loss_fn = nn.CrossEntropyLoss()

    def compute_loss(self, logits, labels):
        """Return the loss between predicted ``logits`` and target ``labels``.

        Replace with your own loss function implementation.
        """
        return self.loss_fn(logits, labels)

    def forward(self, embeddings, labels=None):
        """Project ``embeddings`` to logits and optionally compute the loss.

        Args:
            embeddings: Input embeddings; the last dimension must equal
                ``in_features``.
            labels: Optional targets; when given, the loss is computed.

        Returns:
            A ``HeadOutput`` carrying the embeddings, logits and loss
            (``None`` when no labels were passed).
        """
        logits = self.linear(embeddings)
        loss = None
        if labels is not None:
            loss = self.compute_loss(logits, labels)
        # NOTE(review): assumes HeadOutput accepts embeddings/logits/loss
        # keyword fields; confirm against liqfit.modeling.heads.
        return HeadOutput(embeddings=embeddings, logits=logits, loss=loss)