ClassificationHead

class liqfit.modeling.ClassificationHead

(in_features: int, out_features: int, pooler: nn.Module, loss_func: nn.Module, bias: bool = True, temperature: float = 1.0, eps: float = 1e-5)

Parameters:

  • in_features (int): Number of input features.

  • out_features (int): Number of output features.

  • pooler (nn.Module): Pooling module to apply when the input is not multi-target.

  • loss_func (nn.Module): Loss function that is called if labels are passed.

  • bias (bool): Whether to use a bias in the nn.Linear layer. (Defaults to True.)

  • temperature (float): Temperature by which the linear layer output is divided to calibrate the predictions (see the sketch after this list). (Defaults to 1.0.)

  • eps (float): Epsilon added to the temperature for numerical stability. (Defaults to 1e-5.)
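
The temperature and eps parameters together calibrate the logits. Below is a minimal sketch of the intended computation, assuming the head simply divides the linear projection by temperature + eps; the exact implementation inside the library may differ.

import torch
import torch.nn as nn

# Hypothetical illustration of temperature calibration, not the library's code.
linear = nn.Linear(512, 20)
temperature, eps = 1.0, 1e-5

pooled = torch.randn((1, 512))                  # already-pooled hidden state
logits = linear(pooled) / (temperature + eps)   # higher temperature -> softer probabilities
probs = logits.softmax(dim=-1)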

Using ClassificationHead

Use ClassificationHead when you need more flexibility in choosing your loss function and pooling method.

import torch

from liqfit.losses import CrossEntropyLoss
from liqfit.modeling.heads import ClassificationHead
from liqfit.modeling.pooling import GlobalAvgPooling1D

# Single-target cross-entropy: one class index per example.
loss = CrossEntropyLoss(multi_target=False)
# Average over the sequence dimension to get one vector per example.
pooler = GlobalAvgPooling1D()

head = ClassificationHead(512, 20, loss_func=loss, pooler=pooler)

embeddings = torch.randn((1, 10, 512))  # (batch, sequence, hidden)
output = head(embeddings)
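
When labels are provided, the configured loss_func is applied to the classification output. The sketch below is hedged: it assumes the head's forward accepts a labels argument and returns the computed loss alongside the logits, which is not spelled out on this page.

import torch

from liqfit.losses import CrossEntropyLoss
from liqfit.modeling.heads import ClassificationHead
from liqfit.modeling.pooling import GlobalAvgPooling1D

head = ClassificationHead(512, 20,
                          loss_func=CrossEntropyLoss(multi_target=False),
                          pooler=GlobalAvgPooling1D())

embeddings = torch.randn((4, 10, 512))   # (batch, sequence, hidden)
labels = torch.randint(0, 20, (4,))      # one class index per example

# Hypothetical call: the loss is computed only because labels are passed.
output = head(embeddings, labels=labels)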