Skip to main content

Training

Quickstart

Installation

To install GLiClass, run the following command:

pip install gliclass

Base Training Script

Load pretrained model

import torch
from gliclass import GLiClassModel
from transformers import AutoTokenizer

# BUG FIX: is_available must be *called*. The bare function object is always
# truthy, so the original expression selected CUDA even on CPU-only machines.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

model_name = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Define Dataset

The dataset for training GLiClass must contain a text field, which holds the text to classify; an all_labels field, which lists all labels to classify from; and a true_labels field, which lists the correct labels for the given text. For more info about datasets, please visit our datasets page

from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding, AugmentationConfig

data = [
{
"text": "A new machine learning platform automates complex data workflows but faces integration issues.",
"all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
"true_labels": ["AI", "integration", "automation"]
}
]

# Configure data augmentation (disabled by default)
augment_config = AugmentationConfig(enabled=False)

train_dataset = GLiClassDataset(
data,
tokenizer,
augment_config,
label_to_description={}, # optional: map labels to descriptions
max_length=1024,
problem_type='multi_label_classification',
)

# Data collator
data_collator = DataCollatorWithPadding(device=device)
Data Augmentation

Enable augmentation to improve training robustness by randomly modifying labels and text during training:

augment_config = AugmentationConfig(
enabled=True,
random_label_removal_prob=0.1, # randomly remove labels
random_label_addition_prob=0.1, # randomly add labels
random_text_addition_prob=0.0, # randomly add text
random_add_description_prob=0.1, # add label descriptions
random_add_synonyms_prob=0.0, # add synonym variations
random_add_examples_prob=0.0, # add example texts
max_num_examples=5, # max examples to add
)
Expected Output
Total labels:  5

Define functions for metrics computation

import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def compute_metrics_multi_label(p):
    """Compute weighted accuracy/precision/recall/F1 for multi-label output.

    Flattens predictions and labels, binarizes both at a 0.5 threshold,
    and reports weighted-average metrics over all label positions.
    """
    predictions, labels = p
    flat_labels = np.where(labels.reshape(-1) > 0.5, 1, 0)
    flat_preds = (predictions.reshape(-1) > 0.5).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        flat_labels, flat_preds, average='weighted'
    )
    return {
        'accuracy': accuracy_score(flat_labels, flat_preds),
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

Train the model

from gliclass.training import TrainingArguments, Trainer

# Training arguments
training_args = TrainingArguments(
output_dir="my-awesome-gliclass-model",
learning_rate=1e-5,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type="cosine",
per_device_eval_batch_size=8,
num_train_epochs=1,
logging_steps=100,
use_cpu = False,
report_to="none",
fp16=False, # Set to True if you want to use mixed precision training
)

# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics_multi_label
)

# Run training
trainer.train()

Full Base Training Script [source]

The following script can be used both for training from scratch and for fine-tuning a pretrained model. It supports data augmentation, LoRA parameter-efficient fine-tuning, and EWC continual learning:

import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import argparse
import json

from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import transformers
from transformers import AutoTokenizer, AutoConfig
from torch.utils.data import WeightedRandomSampler
from packaging import version
from peft import LoraConfig, get_peft_model, TaskType

import random
import torch

from gliclass import GLiClassModelConfig, GLiClassModel
from gliclass.training import TrainingArguments, Trainer
from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset, AugmentationConfig

class CustomTrainer(Trainer):
    """Trainer that can draw training samples with diversity-based weights."""

    def __init__(self, *args, use_weighted_sampling=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_weighted_sampling = use_weighted_sampling

    def _get_train_sampler(self, train_dataset) -> torch.utils.data.Sampler:
        # Without weighted sampling, defer entirely to the parent Trainer.
        if not self.use_weighted_sampling:
            return super()._get_train_sampler()

        # Sample with replacement, weighting each example by its diversity
        # score as reported by the dataset itself.
        sample_weights = train_dataset.get_diversity()
        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(train_dataset),
            replacement=True,
        )

def compute_metrics(p, problem_type='multi_label_classification'):
    """Compute evaluation metrics for either classification mode.

    Args:
        p: Tuple of (predictions, labels).
        problem_type: 'single_label_classification' or
            'multi_label_classification'.

    Returns:
        Dict with accuracy, precision, recall, and weighted F1.

    Raises:
        NotImplementedError: For any unsupported problem type.
    """
    predictions, labels = p
    labels = labels.reshape(-1)

    if problem_type == 'single_label_classification':
        # One class per sample: pick the highest-scoring class index.
        preds = np.argmax(predictions, axis=1)
    elif problem_type == 'multi_label_classification':
        # Flatten and binarize both predictions and labels at 0.5.
        preds = (predictions.reshape(-1) > 0.5).astype(int)
        labels = np.where(labels > 0.5, 1, 0)
    else:
        raise NotImplementedError(f"{problem_type} is not implemented.")

    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }


def load_dataset(data_path: str) -> list:
    """Load dataset from a JSON or JSONL file.

    Files ending in ``.jsonl`` are parsed line by line (blank lines are
    skipped); anything else is parsed as a single JSON document.

    Args:
        data_path: Path to a JSON or JSONL data file.

    Returns:
        List of data samples.
    """
    # Explicit encoding: without it, open() falls back to the platform
    # default, which breaks non-ASCII datasets on some systems.
    with open(data_path, 'r', encoding='utf-8') as f:
        if data_path.endswith('.jsonl'):
            data = [json.loads(line) for line in f if line.strip()]
        else:
            data = json.load(f)
    return data


def get_lora_target_modules(model, args):
    """Resolve which module names LoRA should adapt.

    An explicit --lora_target_modules list wins. Otherwise, auto-detect every
    short module name that refers exclusively to torch.nn.Linear layers
    (names like "output" that also match container modules are dropped),
    minus common classification-head names so the output layer is left alone.

    Args:
        model: The GLiClassModel.
        args: Parsed arguments.

    Returns:
        Sorted list of module-name patterns to target.
    """
    if args.lora_target_modules:
        return args.lora_target_modules

    linear_leaves, other_leaves = set(), set()
    for full_name, module in model.named_modules():
        leaf = full_name.rsplit('.', 1)[-1]
        if isinstance(module, torch.nn.Linear):
            linear_leaves.add(leaf)
        else:
            other_leaves.add(leaf)

    # Keep only names that NEVER refer to a non-Linear module, then drop
    # typical output-head names so the classifier is not LoRA-adapted.
    candidates = linear_leaves - other_leaves
    candidates -= {'classifier', 'score', 'out_proj', 'dense_out', 'head'}

    detected = sorted(candidates)
    print(f"Auto-detected LoRA target modules: {detected}")
    return detected


def apply_lora(model, args):
    """Wrap the model with LoRA adapters configured from CLI arguments.

    Args:
        model: The GLiClassModel.
        args: Parsed arguments containing the LoRA hyperparameters.

    Returns:
        PeftModel with LoRA adapters applied.
    """
    config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=get_lora_target_modules(model, args),
        lora_dropout=args.lora_dropout,
        bias=args.lora_bias,
        modules_to_save=args.lora_modules_to_save,
    )

    peft_model = get_peft_model(model, config)
    # Log how many parameters remain trainable after wrapping.
    peft_model.print_trainable_parameters()
    return peft_model


def main(args):
    """End-to-end training entry point.

    Builds (or loads) the model and tokenizer, prepares train/eval datasets,
    optionally applies LoRA adapters and EWC continual-learning regularization,
    trains with CustomTrainer, and saves the final (possibly merged) model.
    """
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # --- Model & tokenizer -------------------------------------------------
    if args.model_name is not None:
        # Fine-tune an existing GLiClass checkpoint.
        model = GLiClassModel.from_pretrained(
            args.model_name,
            focal_loss_alpha=args.focal_loss_alpha,
            focal_loss_gamma=args.focal_loss_gamma,
            focal_loss_reduction=args.focal_loss_reduction,
        )
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    else:
        # Assemble a fresh GLiClass model around the chosen encoder.
        tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
        encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)

        label_model_config = None
        if args.label_model_name is not None:
            label_model_config = AutoConfig.from_pretrained(args.label_model_name)

        glicalss_config = GLiClassModelConfig(
            encoder_config=encoder_config,
            encoder_model=args.encoder_model_name,
            label_model_name=args.label_model_name,
            label_model_config=label_model_config,
            class_token_index=len(tokenizer),
            text_token_index=len(tokenizer) + 1,
            example_token_index=len(tokenizer) + 2,
            pooling_strategy=args.pooler_type,
            class_token_pooling=args.class_token_pooling,
            scorer_type=args.scorer_type,
            use_lstm=args.use_lstm,
            focal_loss_alpha=args.focal_loss_alpha,
            focal_loss_gamma=args.focal_loss_gamma,
            focal_loss_reduction=args.focal_loss_reduction,
            contrastive_loss_coef=args.contrastive_loss_coef,
            normalize_features=args.normalize_features,
            extract_text_features=args.extract_text_features,
            architecture_type=args.architecture_type,
            prompt_first=args.prompt_first,
            squeeze_layers=args.squeeze_layers,
            layer_wise=args.layer_wise,
            encoder_layer_id=args.encoder_layer_id,
            shuffle_labels=args.shuffle_labels,
            dropout=args.dropout,
            use_segment_embeddings=args.use_segment_embeddings,
        )

        model = GLiClassModel(glicalss_config, from_pretrained=True).to(dtype=torch.float32)

        # These architectures rely on special marker tokens; add them to a
        # freshly built model (a pretrained checkpoint already has them).
        if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder', 'encoder-decoder-cls'}:
            new_words = ["<<LABEL>>", "<<SEP>>", "<<EXAMPLE>>"]
            tokenizer.add_tokens(new_words, special_tokens=True)
            model.resize_token_embeddings(len(tokenizer))

    model.to(device)

    # Optionally wrap the model with LoRA adapters.
    if args.use_lora:
        print("\n--- Applying LoRA adapters ---")
        model = apply_lora(model, args)
        print("--- LoRA adapters applied ---\n")

    # A dedicated tokenizer is needed when labels are embedded by their own model.
    labels_tokenizer = None
    if model.config.label_model_name is not None:
        labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)

    model.config.problem_type = args.problem_type

    # --- Data --------------------------------------------------------------
    data = load_dataset(args.data_path)
    # Rough length filter (~4 chars per token) to drop overly long samples.
    data = [item for item in data if len(item['text']) // 4 < 2048]

    print(f'Dataset size: {len(data)}')

    random.shuffle(data)
    print('Dataset is shuffled...')

    split_point = int(len(data) * 0.9)
    train_data = data[:split_point]
    test_data = data[split_point:]
    print('Dataset is splitted...')

    # Augmentation config for training data (all knobs come from the CLI).
    augment_config = AugmentationConfig(
        enabled=args.enable_augmentation,
        random_label_removal_prob=args.random_label_removal_prob,
        random_label_addition_prob=args.random_label_addition_prob,
        random_text_addition_prob=args.random_text_addition_prob,
        random_add_description_prob=args.random_add_description_prob,
        random_add_synonyms_prob=args.random_add_synonyms_prob,
        random_add_examples_prob=args.random_add_examples_prob,
        max_num_examples=args.max_num_examples,
    )

    if args.labels_desc_path is not None:
        labels_descriptions = load_dataset(args.labels_desc_path)
        label_to_description = {item.get("label"): item for item in labels_descriptions}
    else:
        label_to_description = {}

    train_dataset = GLiClassDataset(
        train_data, tokenizer, augment_config, label_to_description,
        args.max_length, args.problem_type, args.architecture_type,
        args.prompt_first, labels_tokenizer=labels_tokenizer,
    )

    # Evaluation data is never augmented.
    test_augment_config = AugmentationConfig(enabled=False)
    test_dataset = GLiClassDataset(
        test_data, tokenizer, test_augment_config, label_to_description,
        args.max_length, args.problem_type, args.architecture_type,
        args.prompt_first, labels_tokenizer=labels_tokenizer,
    )

    # Previous-task data lets EWC estimate Fisher information.
    prev_dataset = None
    if args.use_ewc and args.prev_data_path is not None:
        print(f'Loading previous dataset for EWC from: {args.prev_data_path}')
        prev_data = load_dataset(args.prev_data_path)
        print(f'Previous dataset size: {len(prev_data)}')

        # Optionally subsample for cheaper Fisher estimation.
        if args.ewc_fisher_samples is not None and args.ewc_fisher_samples < len(prev_data):
            random.shuffle(prev_data)
            prev_data = prev_data[:args.ewc_fisher_samples]
            print(f'Using {len(prev_data)} samples for Fisher estimation')

        prev_dataset = GLiClassDataset(
            prev_data, tokenizer, test_augment_config, label_to_description,
            args.max_length, args.problem_type, args.architecture_type,
            args.prompt_first, labels_tokenizer=labels_tokenizer,
        )

    data_collator = DataCollatorWithPadding(device=device)

    # --- Training setup ----------------------------------------------------
    training_args = TrainingArguments(
        output_dir=args.save_path,
        learning_rate=args.encoder_lr,
        weight_decay=args.encoder_weight_decay,
        others_lr=args.others_lr,
        others_weight_decay=args.others_weight_decay,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_epochs,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        dataloader_num_workers=args.num_workers,
        logging_steps=100,
        use_cpu=False,
        report_to="none",
        fp16=args.fp16,
        save_only_model=True,
        use_ewc=args.use_ewc,
        ewc_lambda=args.ewc_lambda,
        ewc_fisher_samples=args.ewc_fisher_samples,
        ewc_normalize_fisher=args.ewc_normalize_fisher,
        ewc_gamma=args.ewc_gamma,
    )

    def compute_metrics_fn(p):
        # Bind the chosen problem_type into the metrics callback.
        return compute_metrics(p, args.problem_type)

    # transformers v5 renamed Trainer's `tokenizer` kwarg to `processing_class`;
    # pick the right one at runtime so both major versions work.
    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
        "eval_dataset": test_dataset,
        "data_collator": data_collator,
        "compute_metrics": compute_metrics_fn,
        "prev_dataset": prev_dataset,  # previous-task data for EWC
    }
    if version.parse(transformers.__version__) < version.parse("5.0.0"):
        trainer_kwargs["tokenizer"] = tokenizer
    else:
        trainer_kwargs["processing_class"] = tokenizer

    trainer = CustomTrainer(**trainer_kwargs)

    # Report EWC status before training starts.
    if args.use_ewc:
        if args.prev_data_path is not None:
            print(f'\nEWC enabled with lambda={args.ewc_lambda}')
        else:
            print('\nWarning: EWC is enabled but no previous data path provided. EWC will not be used.')

    trainer.train()

    # --- Saving ------------------------------------------------------------
    final_output_dir = os.path.join(args.save_path, 'final_model')

    if args.use_lora:
        print("\n--- Merging LoRA weights into base model ---")
        # Fold the LoRA adapters back into the base weights so the result is
        # a standard checkpoint with no adapter artifacts.
        merged_model = model.merge_and_unload()
        print("--- LoRA weights merged successfully ---")

        merged_model.save_pretrained(final_output_dir)
        tokenizer.save_pretrained(final_output_dir)
        print(f'Merged model saved to {final_output_dir}')

        # Optionally keep the unmerged adapter as a separate artifact too.
        if args.save_lora_adapter:
            adapter_dir = os.path.join(args.save_path, 'lora_adapter')
            model.save_pretrained(adapter_dir)
            tokenizer.save_pretrained(adapter_dir)
            print(f'LoRA adapter also saved separately to {adapter_dir}')
    else:
        model.save_pretrained(final_output_dir)
        tokenizer.save_pretrained(final_output_dir)
        print(f'Final model saved to {final_output_dir}')


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a boolean CLI value.

        BUG FIX: argparse's `type=bool` treats ANY non-empty string as True
        (e.g. `--fp16 False` used to enable fp16). Parse the text explicitly
        and reject anything that is not a recognized boolean spelling.
        """
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in {'true', '1', 'yes', 'y'}:
            return True
        if lowered in {'false', '0', 'no', 'n'}:
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='Train GLiClass model with optional LoRA and EWC for continual learning')

    # Model arguments
    parser.add_argument('--model_name', type=str, default=None,
                        help='Pretrained model name or path')
    parser.add_argument('--encoder_model_name', type=str, default='microsoft/deberta-v3-base',
                        help='Encoder model name')
    parser.add_argument('--label_model_name', type=str, default="BAAI/bge-small-en-v1.5",
                        help='Label model name')

    # Path arguments
    parser.add_argument('--save_path', type=str, default='models/',
                        help='Path to save trained model')
    parser.add_argument('--data_path', type=str, default='data/zero-cats.json',
                        help='Path to training data JSON or JSONL file')
    parser.add_argument('--prev_data_path', type=str, default=None,
                        help='Path to previous task data for EWC (required if use_ewc=True)')
    parser.add_argument('--labels_desc_path', type=str, default=None)

    # Model architecture arguments
    parser.add_argument('--problem_type', type=str, default='multi_label_classification',
                        choices=['single_label_classification', 'multi_label_classification'])
    parser.add_argument('--pooler_type', type=str, default='first')
    parser.add_argument('--scorer_type', type=str, default='mlp')
    parser.add_argument('--architecture_type', type=str, default='uni-encoder')
    parser.add_argument('--class_token_pooling', type=str, default='average')
    parser.add_argument('--normalize_features', type=_str2bool, default=False)
    parser.add_argument('--extract_text_features', type=_str2bool, default=False)
    parser.add_argument('--prompt_first', type=_str2bool, default=True)
    parser.add_argument('--use_lstm', type=_str2bool, default=False)
    parser.add_argument('--squeeze_layers', type=_str2bool, default=False)
    parser.add_argument('--layer_wise', type=_str2bool, default=False)
    parser.add_argument('--encoder_layer_id', type=int, default=-1)
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--shuffle_labels', type=_str2bool, default=True)
    parser.add_argument('--use_segment_embeddings', type=_str2bool, default=True)

    # Training arguments
    parser.add_argument('--num_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=2)
    parser.add_argument('--encoder_lr', type=float, default=1e-5)
    parser.add_argument('--others_lr', type=float, default=1e-5)
    parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
    parser.add_argument('--others_weight_decay', type=float, default=0.01)
    parser.add_argument('--warmup_ratio', type=float, default=0.05)
    parser.add_argument('--lr_scheduler_type', type=str, default='cosine')
    parser.add_argument('--max_length', type=int, default=4096)
    parser.add_argument('--save_steps', type=int, default=1000)
    parser.add_argument('--save_total_limit', type=int, default=3)
    parser.add_argument('--num_workers', type=int, default=12)
    parser.add_argument('--fp16', type=_str2bool, default=False)

    # Augmentation parameters
    parser.add_argument('--enable_augmentation', type=_str2bool, default=False)
    parser.add_argument('--random_label_removal_prob', type=float, default=0.0)
    parser.add_argument('--random_label_addition_prob', type=float, default=0.0)
    parser.add_argument('--random_text_addition_prob', type=float, default=0.0)
    parser.add_argument('--random_add_description_prob', type=float, default=0.1)
    parser.add_argument('--random_add_synonyms_prob', type=float, default=0.0)
    parser.add_argument('--random_add_examples_prob', type=float, default=0.0)
    parser.add_argument('--max_num_examples', type=int, default=5)

    # Loss arguments
    parser.add_argument('--focal_loss_alpha', type=float, default=0.8)
    parser.add_argument('--focal_loss_gamma', type=float, default=-1)
    parser.add_argument('--focal_loss_reduction', type=str, default='none',
                        choices=['none', 'mean', 'sum'])
    parser.add_argument('--contrastive_loss_coef', type=float, default=0.)

    # LoRA arguments
    parser.add_argument('--use_lora', action='store_true',
                        help='Enable LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning')
    parser.add_argument('--lora_r', type=int, default=1024,
                        help='LoRA rank (dimensionality of low-rank matrices)')
    parser.add_argument('--lora_alpha', type=int, default=2048,
                        help='LoRA alpha (scaling factor, effective scaling = alpha/r)')
    parser.add_argument('--lora_dropout', type=float, default=0.05,
                        help='Dropout probability for LoRA layers')
    parser.add_argument('--lora_bias', type=str, default='none',
                        choices=['none', 'all', 'lora_only'],
                        help='Which bias parameters to train: none, all, or lora_only')
    parser.add_argument('--lora_target_modules', type=str, nargs='+', default=None,
                        help='List of module names to apply LoRA to (e.g., query_proj value_proj). '
                             'If not specified, auto-detects linear layers.')
    parser.add_argument('--lora_modules_to_save', type=str, nargs='+', default=None,
                        help='List of module names to fully train (not LoRA-adapted), '
                             'e.g., classifier heads that should be trained normally')
    parser.add_argument('--save_lora_adapter', action='store_true',
                        help='Save the LoRA adapter separately in addition to the merged model')

    # EWC arguments
    parser.add_argument('--use_ewc', action='store_true',
                        help='Enable Elastic Weight Consolidation for continual learning')
    parser.add_argument('--ewc_lambda', type=float, default=100.0,
                        help='Lambda parameter for EWC penalty (higher = more regularization)')
    parser.add_argument('--ewc_fisher_samples', type=int, default=None,
                        help='Number of samples to use for Fisher information estimation (None = use all)')
    parser.add_argument('--ewc_normalize_fisher', type=_str2bool, default=True,
                        help='Whether to normalize Fisher information values')
    parser.add_argument('--ewc_gamma', type=float, default=0.95,
                        help='Decay factor for Online EWC (0 < gamma < 1)')

    args = parser.parse_args()

    # Validate EWC arguments
    if args.use_ewc and args.prev_data_path is None:
        print("Warning: --use_ewc is set but --prev_data_path is not provided.")
        print("EWC requires previous task data to compute Fisher information.")
        print("Training will proceed without EWC.")

    main(args)

RL Training Script

The GLiClass framework also supports reinforcement learning; you can start training models with it after just a couple of changes to your training script.

Load pretrained model

This step remains unchanged

import torch
from gliclass import GLiClassModel
from transformers import AutoTokenizer

# BUG FIX: is_available must be *called*. The bare function object is always
# truthy, so the original expression selected CUDA even on CPU-only machines.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

model_name = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Initialize RL training components

from transformers import AutoModelForSequenceClassification
from gliclass.pipeline import ZeroShotClassificationPipeline

# Value model for advantage estimation
value_model = AutoModelForSequenceClassification.from_pretrained(
    model.config.encoder_model_name, num_labels=1
)
value_model.resize_token_embeddings(len(tokenizer))

# Reference model for baseline comparisons; in most cases the training
# model itself can serve as the reference.
reference_model = GLiClassModel.from_pretrained(model_name)
reference_tokenizer = AutoTokenizer.from_pretrained(model_name)
reference_pipe = ZeroShotClassificationPipeline(
    reference_model,
    reference_tokenizer,
    classification_type='multi-label',
    progress_bar=False,
    device=device,
)

Define Dataset

This step remains unchanged (same dataset format, but now requires AugmentationConfig)

from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding, AugmentationConfig

data = [
{
"text": "A new machine learning platform automates complex data workflows but faces integration issues.",
"all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
"true_labels": ["AI", "integration", "automation"]
}
]

augment_config = AugmentationConfig(enabled=False)

train_dataset = GLiClassDataset(
data,
tokenizer,
augment_config,
label_to_description={},
max_length=1024,
problem_type='multi_label_classification',
)

# Data collator
data_collator = DataCollatorWithPadding(device=device)
Expected Output
Total labels:  5

Define functions for metrics computation

This step remains unchanged

import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def compute_metrics_multi_label(p):
    """Compute weighted accuracy/precision/recall/F1 for multi-label output.

    Flattens predictions and labels, binarizes both at a 0.5 threshold,
    and reports weighted-average metrics over all label positions.
    """
    predictions, labels = p
    flat_labels = np.where(labels.reshape(-1) > 0.5, 1, 0)
    flat_preds = (predictions.reshape(-1) > 0.5).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        flat_labels, flat_preds, average='weighted'
    )
    return {
        'accuracy': accuracy_score(flat_labels, flat_preds),
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

Define reward function

def default_f1_reward(
    probs: torch.Tensor,
    actions: torch.Tensor,
    original_targets: torch.Tensor,
    valid_mask: torch.Tensor
) -> torch.Tensor:
    """Per-sample F1 reward computed from predicted/target index sets.

    Args:
        probs: (N, T) probabilities (unused; kept for interface consistency).
        actions: (N, T) predicted labels in {0, 1}.
        original_targets: (N, T) ground-truth labels in {0, 1}.
        valid_mask: (N, T) 1 for valid positions, 0 for invalid/padding.

    Returns:
        (N, 1) tensor of F1 scores, detached and on probs' device.
    """
    scores = []
    for preds_row, targets_row, mask_row in zip(actions, original_targets, valid_mask):
        # Zero out invalid positions before extracting index sets.
        masked_preds = preds_row * mask_row
        masked_targets = targets_row * mask_row

        predicted = set((masked_preds == 1).nonzero(as_tuple=True)[0].tolist())
        expected = set((masked_targets == 1).nonzero(as_tuple=True)[0].tolist())
        overlap = len(predicted & expected)

        # Classical set-based precision / recall / F1, with empty-set guards.
        precision = overlap / len(predicted) if predicted else 0.0
        recall = overlap / len(expected) if expected else 0.0
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        scores.append(f1)

    result = torch.tensor(scores, dtype=torch.float).unsqueeze(-1)
    return result.detach().to(probs.device)

Train the model with RLTrainer

from gliclass.training import RLTrainerConfig, RLTrainer

training_args = RLTrainerConfig(
output_dir="my-awesome-rl-gliclass-model",
learning_rate=1e-5,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type="cosine",
per_device_eval_batch_size=8,
num_train_epochs=1,
logging_steps=100,
use_cpu = False,
report_to="none",
fp16=False,
cliprange=0.2,
num_rl_iters=2
)

trainer = RLTrainer(
model=model,
value_model=value_model,
reference_model=reference_pipe,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics_multi_label,
reward_components={
'micro_f1': default_f1_reward,
},
)

trainer.train()

important

To avoid AttributeError during run in notebooks add following lines after initializing trainer:

trainer = RLTrainer(
model=model,
...
)

from transformers.utils.notebook import NotebookProgressCallback
trainer.remove_callback(NotebookProgressCallback)

trainer.train()

Full RL Training Script [source]

import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import argparse
import json

from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification

import random
import torch

from gliclass import GLiClassModelConfig, GLiClassModel, ZeroShotClassificationPipeline
from gliclass.training import TrainingArguments, Trainer, RLTrainerConfig, RLTrainer
from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset, AugmentationConfig
from gliclass.utils import default_f1_reward

def accuracy_reward(probs, actions, targets, valid_mask):
    """Reward of 1.0 when the top-scoring valid class matches the target's
    top class, else 0.0; returned with shape (N, 1)."""
    masked_probs = probs * valid_mask
    predicted_idx = masked_probs.argmax(dim=-1)
    target_idx = targets.argmax(dim=-1)
    return (predicted_idx == target_idx).float().unsqueeze(1)

def recall_reward(
    probs: torch.Tensor,
    actions: torch.Tensor,
    original_targets: torch.Tensor,
    valid_mask: torch.Tensor
) -> torch.Tensor:
    """Per-sample recall TP / (TP + FN) over valid positions, shape (N, 1)."""
    preds = actions * valid_mask
    gold = original_targets * valid_mask

    true_pos = (preds * gold).sum(dim=-1)
    false_neg = ((1 - preds) * gold).sum(dim=-1)

    # Small epsilon avoids division by zero for rows with no positive targets.
    recall = true_pos / (true_pos + false_neg + 1e-8)
    return recall.detach().unsqueeze(1)

def main(args):
    """RL training entry point.

    Builds (or loads) the policy model, optional value and reference models,
    prepares train/eval datasets, and runs RLTrainer.
    """
    # BUG FIX: is_available must be *called*; the bare function object is
    # always truthy, so the original always selected CUDA even on CPU-only
    # machines.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    if args.model_name is not None:
        # Fine-tune an existing GLiClass checkpoint.
        model = GLiClassModel.from_pretrained(args.model_name,
                                              focal_loss_alpha=args.focal_loss_alpha,
                                              focal_loss_gamma=args.focal_loss_gamma)
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    else:
        # Assemble a fresh GLiClass model around the chosen encoder.
        tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
        encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)

        # BUG FIX: initialize to None so the config build below does not hit
        # a NameError when --label_model_name is not provided.
        label_model_config = None
        if args.label_model_name is not None:
            label_model_config = AutoConfig.from_pretrained(args.label_model_name)

        glicalss_config = GLiClassModelConfig(
            encoder_config=encoder_config,
            encoder_model=args.encoder_model_name,
            label_model_name=args.label_model_name,
            label_model_config=label_model_config,
            class_token_index=len(tokenizer),
            text_token_index=len(tokenizer) + 1,
            pooling_strategy=args.pooler_type,
            scorer_type=args.scorer_type,
            use_lstm=args.use_lstm,
            focal_loss_alpha=args.focal_loss_alpha,
            focal_loss_gamma=args.focal_loss_gamma,
            labels_smoothing=args.labels_smoothing,
            entropy_beta=args.entropy_beta,
            kl_beta=args.kl_beta,
            contrastive_loss_coef=args.contrastive_loss_coef,
            normalize_features=args.normalize_features,
            extract_text_features=args.extract_text_features,
            architecture_type=args.architecture_type,
            prompt_first=args.prompt_first,
            squeeze_layers=args.squeeze_layers
        )

        glicalss_config.problem_type = args.problem_type

        model = GLiClassModel(glicalss_config, from_pretrained=True)

        # These architectures rely on special marker tokens; add them to a
        # freshly built model (a pretrained checkpoint already has them).
        if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:
            new_words = ["<<LABEL>>", "<<SEP>>"]
            tokenizer.add_tokens(new_words, special_tokens=True)
            model.resize_token_embeddings(len(tokenizer))

    # Optional value model for advantage estimation.
    if args.set_value_model:
        value_model = AutoModelForSequenceClassification.from_pretrained(model.config.encoder_model_name, num_labels=1)
        value_model.resize_token_embeddings(len(tokenizer))
    else:
        value_model = None

    # Optional reference model used for baseline comparisons.
    if args.reference_model is not None:
        reference_model = GLiClassModel.from_pretrained(args.reference_model)
        reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_model)
        reference_pipe = ZeroShotClassificationPipeline(reference_model, reference_tokenizer,
                                                        classification_type='multi-label',
                                                        progress_bar=False, device=device)
    else:
        reference_pipe = None

    # A dedicated tokenizer is needed when labels are embedded by their own model.
    if args.label_model_name is not None:
        labels_tokenizer = AutoTokenizer.from_pretrained(args.label_model_name)
    else:
        labels_tokenizer = None

    model.to(device)

    # Load and split the dataset (90% train / 10% eval).
    with open(args.data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print('Dataset size:', len(data))
    random.shuffle(data)
    print('Dataset is shuffled...')

    split_point = int(len(data) * 0.9)
    train_data = data[:split_point]
    test_data = data[split_point:]

    print('Dataset is splitted...')

    # RL training runs without data augmentation.
    augment_config = AugmentationConfig(enabled=False)

    train_dataset = GLiClassDataset(train_data, tokenizer, augment_config,
                                    {}, args.max_length,
                                    args.problem_type, args.architecture_type,
                                    args.prompt_first, labels_tokenizer=labels_tokenizer)
    test_dataset = GLiClassDataset(test_data, tokenizer, augment_config,
                                   {}, args.max_length, args.problem_type,
                                   args.architecture_type, args.prompt_first,
                                   labels_tokenizer=labels_tokenizer)

    data_collator = DataCollatorWithPadding(device=device)

    training_args = RLTrainerConfig(
        output_dir=args.save_path,
        learning_rate=args.encoder_lr,
        weight_decay=args.encoder_weight_decay,
        others_lr=args.others_lr,
        others_weight_decay=args.others_weight_decay,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        evaluation_strategy="epoch",
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        dataloader_num_workers=args.num_workers,
        logging_steps=100,
        use_cpu=False,
        report_to="none",
        fp16=args.fp16,
        cliprange=args.clip_range,
        num_rl_iters=args.num_rl_iters
    )

    trainer = RLTrainer(
        model=model,
        value_model=value_model,
        reference_model=reference_pipe,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        reward_components={
            'micro_f1': default_f1_reward,
        },
    )
    trainer.train()

if __name__ == '__main__':
    def _str2bool(value):
        """Parse a boolean CLI value.

        BUG FIX: argparse's `type=bool` treats ANY non-empty string as True
        (e.g. `--fp16 False` used to enable fp16). Parse the text explicitly
        and reject anything that is not a recognized boolean spelling.
        """
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in {'true', '1', 'yes', 'y'}:
            return True
        if lowered in {'false', '0', 'no', 'n'}:
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="knowledgator/gliclass-modern-base-v2.0-init")
    parser.add_argument('--encoder_model_name', type=str, default='microsoft/deberta-v3-small')
    parser.add_argument('--label_model_name', type=str, default="BAAI/bge-small-en-v1.5")
    parser.add_argument('--reference_model', type=str, default=None)
    parser.add_argument('--set_value_model', type=_str2bool, default=True)
    parser.add_argument('--save_path', type=str, default='models/')
    parser.add_argument('--data_path', type=str, default='data/zero-cats.json')
    parser.add_argument('--problem_type', type=str, default='multi_label_classification')
    parser.add_argument('--pooler_type', type=str, default='avg')
    parser.add_argument('--scorer_type', type=str, default='simple')
    parser.add_argument('--architecture_type', type=str, default='uni-encoder')
    parser.add_argument('--normalize_features', type=_str2bool, default=False)
    parser.add_argument('--extract_text_features', type=_str2bool, default=False)
    parser.add_argument('--prompt_first', type=_str2bool, default=True)
    parser.add_argument('--use_lstm', type=_str2bool, default=False)
    parser.add_argument('--squeeze_layers', type=_str2bool, default=False)
    parser.add_argument('--num_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--encoder_lr', type=float, default=2e-6)
    parser.add_argument('--others_lr', type=float, default=3e-6)
    parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
    parser.add_argument('--others_weight_decay', type=float, default=0.01)
    parser.add_argument('--warmup_ratio', type=float, default=0.05)
    parser.add_argument('--lr_scheduler_type', type=str, default='linear')
    parser.add_argument('--focal_loss_alpha', type=float, default=-1)
    parser.add_argument('--focal_loss_gamma', type=float, default=-1)
    parser.add_argument('--labels_smoothing', type=float, default=-1)
    parser.add_argument('--entropy_beta', type=float, default=-1)
    parser.add_argument('--kl_beta', type=float, default=0.1)
    parser.add_argument('--clip_range', type=float, default=0.2)
    parser.add_argument('--num_rl_iters', type=int, default=2)
    parser.add_argument('--contrastive_loss_coef', type=float, default=0.)
    parser.add_argument('--max_length', type=int, default=2048)
    parser.add_argument('--save_steps', type=int, default=300)
    parser.add_argument('--save_total_limit', type=int, default=3)
    parser.add_argument('--num_workers', type=int, default=12)
    parser.add_argument('--fp16', type=_str2bool, default=False)
    args = parser.parse_args()

    main(args)
IMPORTANT

Evaluation

Once you have trained your model, you will most likely want to evaluate it. We have already prepared a test_gliclass.py[source] script for you that will help you to evaluate the model on 13 different zero-shot datasets.

Enter the repo and activate your env

cd GLiClass
source venv/bin/activate
note

If you don't have gliclass framework installed, please check out our installation guide first.

Run evaluation script

python test_gliclass.py --model knowledgator/gliclass-base-v1.0 --api_key YOUR_KEY_IF_REQUIRED