LiqFitModel

class liqfit.modeling.LiqFitModel(
    config: PretrainedConfig,
    backbone: LiqFitBackbone | nn.Module | PreTrainedModel,
    head: Optional[LiqFitHead] = None,
    loss_func: Optional[nn.Module] = None,
    normalize_backbone_embeddings: bool = False,
    labels_name: str = "labels",
    push_backbone_only: bool = False
)

Parameters:

  • config (PretrainedConfig): Backbone configuration.

  • backbone (LiqFitBackbone | nn.Module | PreTrainedModel): Pretrained backbone model.

  • head (Optional[LiqFitHead]): Downstream head applied to the backbone output. (Defaults to None).

  • loss_func (Optional[nn.Module]): Loss function called after each forward pass when labels are passed. (Defaults to None).

  • normalize_backbone_embeddings (bool): Whether to normalize the backbone's output embeddings with `torch.nn.functional.normalize`. (Defaults to False).

  • labels_name (str): Name of the keyword argument that carries the labels when calling the forward method. (Defaults to "labels").

  • push_backbone_only (bool, optional): Whether to push only the backbone model, instead of the whole wrapped model, to the Hugging Face Hub. (Defaults to False).
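
To make the interplay of `head`, `loss_func`, `normalize_backbone_embeddings`, and `labels_name` concrete, here is a minimal, hypothetical sketch of the forward flow these parameters describe, written in plain Python with nested lists standing in for tensors. It is not LiqFit's implementation: `sketch_forward`, `l2_normalize`, `toy_backbone`, and `squared_error` are illustrative names, and the ordering (labels looked up by name, backbone, optional normalization, head, then loss only when labels are present) is an assumption inferred from the parameter descriptions. `l2_normalize` mirrors the defaults of `torch.nn.functional.normalize` (`p=2`, `dim=1`, `eps=1e-12`), which divides each row by max(||row||_2, eps).

```python
import math
from typing import Callable, List, Optional, Tuple

Matrix = List[List[float]]


def l2_normalize(rows: Matrix, eps: float = 1e-12) -> Matrix:
    """Divide each row by max(L2 norm, eps), like F.normalize(x, p=2, dim=1) on a 2-D tensor."""
    normalized = []
    for row in rows:
        norm = math.sqrt(sum(v * v for v in row))
        denom = max(norm, eps)
        normalized.append([v / denom for v in row])
    return normalized


def sketch_forward(
    backbone: Callable[..., Matrix],
    head: Optional[Callable[[Matrix], Matrix]],
    loss_func: Optional[Callable[[Matrix, Matrix], float]],
    normalize_backbone_embeddings: bool = False,
    labels_name: str = "labels",
    **inputs,
) -> Tuple[Matrix, Optional[float]]:
    """Hypothetical forward pass; NOT LiqFit's actual code."""
    # Labels are looked up under the configurable keyword name and kept away from the backbone.
    labels = inputs.pop(labels_name, None)
    embeddings = backbone(**inputs)
    if normalize_backbone_embeddings:
        embeddings = l2_normalize(embeddings)
    # Without a head, the (possibly normalized) embeddings are returned as-is.
    outputs = head(embeddings) if head is not None else embeddings
    # The loss is only computed when both labels and a loss function are available.
    loss = loss_func(outputs, labels) if labels is not None and loss_func is not None else None
    return outputs, loss


def toy_backbone(input_ids):
    # Pretend every sequence embeds to the vector [3.0, 4.0].
    return [[3.0, 4.0] for _ in input_ids]


def squared_error(outputs, targets):
    # Sum of squared differences over all entries.
    return sum((o - t) ** 2 for o_row, t_row in zip(outputs, targets) for o, t in zip(o_row, t_row))


outputs, loss = sketch_forward(
    toy_backbone, None, squared_error,
    normalize_backbone_embeddings=True, labels_name="targets",
    input_ids=[[1, 2, 3]], targets=[[0.6, 0.8]],
)
print(outputs, loss)  # [[0.6, 0.8]] 0.0
```

Popping the labels before calling the backbone is what makes a configurable `labels_name` useful in this sketch: the backbone never sees an unexpected keyword, whatever the labels are called.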

Using the `LiqFitModel` class with the `transformers` library:

```python
from liqfit.modeling import LiqFitModel
from transformers import AutoModel

# Load a pretrained transformer to serve as the backbone (model name elided).
backbone_model = AutoModel.from_pretrained(...)
model = LiqFitModel(backbone_model.config, backbone_model)
```

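Using `LiqFitModel` with one of the available heads:

```python
import torch
from liqfit.modeling import ClassClassificationHead, LiqFitBackbone, LiqFitModel
from transformers import AutoModel


class MyBackboneModel(LiqFitBackbone):
    def __init__(self):
        backbone_model = AutoModel.from_pretrained(...)
        # Register the wrapped transformer and its config with LiqFitBackbone.
        super().__init__(backbone_model.config, backbone_model)

    def encode(self, input_ids, attention_mask=None):
        # attention_mask defaults to None so that model(x) works with input_ids only.
        output = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # output[0] is the last hidden state: (batch_size, seq_len, hidden_size).
        last_hidden_state = output[0]
        return last_hidden_state


backbone = MyBackboneModel()
# Head with 3 outputs on top of the backbone's hidden size.
head = ClassClassificationHead(backbone.config.hidden_size, 3, multi_target=True)

model = LiqFitModel(backbone.config, backbone, head)

# Dummy batch: one sequence of 20 random token ids in [0, 20).
x = torch.randint(0, 20, (1, 20))
out = model(x)
```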
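
The example depends on the subclass contract of `LiqFitBackbone`: the subclass supplies an `encode` method that maps token ids to hidden states. The plain-Python sketch below illustrates that pattern under the assumption that the wrapper's call path delegates to `encode` and may omit `attention_mask`. `SketchBackbone` and `ToyTokenBackbone` are hypothetical names, not part of LiqFit, and nested lists stand in for the `(batch_size, seq_len, hidden_size)` tensor that a `transformers` model returns as `last_hidden_state`.

```python
from abc import ABC, abstractmethod
from typing import List, Optional

# (batch_size, seq_len, hidden_size) nested lists, mimicking last_hidden_state.
HiddenStates = List[List[List[float]]]


class SketchBackbone(ABC):
    """Hypothetical stand-in for a backbone wrapper whose call path delegates to encode()."""

    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size

    @abstractmethod
    def encode(self, input_ids: List[List[int]], attention_mask: Optional[List[List[int]]] = None) -> HiddenStates:
        """Subclasses turn token ids into per-token hidden states."""

    def __call__(self, input_ids, attention_mask=None) -> HiddenStates:
        # Callers may omit attention_mask, so encode() must accept None.
        return self.encode(input_ids, attention_mask=attention_mask)


class ToyTokenBackbone(SketchBackbone):
    def encode(self, input_ids, attention_mask=None):
        if attention_mask is None:
            # No mask given: treat every token as a real token.
            attention_mask = [[1] * len(seq) for seq in input_ids]
        # Each token becomes a hidden_size vector filled with its id; masked positions become zeros.
        return [
            [[float(tok) * m] * self.hidden_size for tok, m in zip(seq, mask)]
            for seq, mask in zip(input_ids, attention_mask)
        ]


backbone = ToyTokenBackbone(hidden_size=2)
print(backbone([[5, 7, 9]], [[1, 1, 0]]))  # [[[5.0, 5.0], [7.0, 7.0], [0.0, 0.0]]]
print(backbone([[1, 2]]))  # [[[1.0, 1.0], [2.0, 2.0]]]
```

Because the abstract base cannot be instantiated, forgetting to implement `encode` fails at construction time rather than at the first forward call.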
