from liqfit.modeling.heads import LiqFitHead
from liqfit.modeling.heads import HeadOutput
from torch import nn


class MyOwnDownstreamHead(LiqFitHead):
    """Template for a custom downstream head built on ``LiqFitHead``.

    The head owns a single linear projection from ``in_features`` to
    ``out_features``. ``compute_loss`` and ``forward`` are left for you to
    implement; until then they raise ``NotImplementedError`` so a
    half-finished head fails loudly instead of returning placeholder data.
    """

    def __init__(self, in_features, out_features):
        """Create the head.

        Args:
            in_features: size of the input embedding dimension passed to
                ``nn.Linear``.
            out_features: size of the output dimension of ``nn.Linear``
                (e.g. number of classes — adapt to your task).
        """
        # Assigning an nn.Linear as an attribute requires the nn.Module
        # machinery to be initialized first; this assumes LiqFitHead derives
        # from torch.nn.Module and accepts a no-argument __init__ — confirm.
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def compute_loss(self, logits, labels):
        """Compute the training loss for this head.

        Args:
            logits: output of the head's projection.
            labels: target values for the batch.

        Returns:
            The loss value for this batch.

        Raises:
            NotImplementedError: until you provide your own loss function.
        """
        # your loss function implementation
        raise NotImplementedError(
            "MyOwnDownstreamHead.compute_loss must be implemented."
        )

    def forward(self, embeddings, labels=None):
        """Run the head on ``embeddings`` and return a ``HeadOutput``.

        Args:
            embeddings: encoder embeddings fed to ``self.linear``.
            labels: optional targets; when given, a loss is typically
                computed via ``self.compute_loss``.

        Returns:
            A ``HeadOutput`` built from your forward results.

        Raises:
            NotImplementedError: until you provide your own forward pass.
        """
        # your forward implementation, for example:
        #   logits = self.linear(embeddings)
        #   loss = self.compute_loss(logits, labels) if labels is not None else None
        #   return HeadOutput(...)  # NOTE(review): fill in with HeadOutput's
        #                           # actual fields — check liqfit's definition.
        raise NotImplementedError(
            "MyOwnDownstreamHead.forward must be implemented and return a HeadOutput."
        )