Training
Quickstart
Installation
To install GLiClass, run the following command:
pip install gliclass
Base Training Script
Load pretrained model
import torch
from gliclass import GLiClassModel
from transformers import AutoTokenizer
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model_name = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
Define Dataset
The dataset for training GLiClass must contain a text field, which holds the text to classify; an all_labels field, which lists all candidate labels to classify from; and a true_labels field, which contains the correct labels for the given text. For more information about datasets, please visit our datasets page.
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding, AugmentationConfig
data = [
{
"text": "A new machine learning platform automates complex data workflows but faces integration issues.",
"all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
"true_labels": ["AI", "integration", "automation"]
}
]
# Configure data augmentation (disabled by default)
augment_config = AugmentationConfig(enabled=False)
train_dataset = GLiClassDataset(
data,
tokenizer,
augment_config,
label_to_description={}, # optional: map labels to descriptions
max_length=1024,
problem_type='multi_label_classification',
)
# Data collator
data_collator = DataCollatorWithPadding(device=device)
Enable augmentation to improve training robustness by randomly modifying labels and text during training:
augment_config = AugmentationConfig(
enabled=True,
random_label_removal_prob=0.1, # randomly remove labels
random_label_addition_prob=0.1, # randomly add labels
random_text_addition_prob=0.0, # randomly add text
random_add_description_prob=0.1, # add label descriptions
random_add_synonyms_prob=0.0, # add synonym variations
random_add_examples_prob=0.0, # add example texts
max_num_examples=5, # max examples to add
)
Expected Output
Total labels: 5
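If you set random_add_description_prob > 0, you should also pass a label_to_description mapping to GLiClassDataset. A minimal sketch with hypothetical descriptions, mirroring how the full script below builds the mapping from --labels_desc_path (one dict per label, keyed by "label"):
# Hypothetical descriptions; the exact schema is an assumption based on
# the full script's label_to_description construction.
label_to_description = {
    "AI": {"label": "AI", "description": "Artificial intelligence and machine learning."},
    "integration": {"label": "integration", "description": "Connecting systems, tools, or services."},
}
# Pass it in place of the empty dict used above:
# train_dataset = GLiClassDataset(data, tokenizer, augment_config,
#                                 label_to_description=label_to_description, ...)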
Define functions for metrics computation
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
def compute_metrics_multi_label(p):
predictions, labels = p
labels = labels.reshape(-1)
predictions = predictions.reshape(-1)
preds = (predictions > 0.5).astype(int)
labels = np.where(labels>0.5, 1, 0)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
Train the model
from gliclass.training import TrainingArguments, Trainer
# Training arguments
training_args = TrainingArguments(
output_dir="my-awesome-gliclass-model",
learning_rate=1e-5,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type="cosine",
per_device_eval_batch_size=8,
num_train_epochs=1,
logging_steps=100,
    use_cpu=False,
report_to="none",
fp16=False, # Set to True if you want to use mixed precision training
)
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics_multi_label
)
# Run training
trainer.train()
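After training, you can sanity-check the fine-tuned model with the zero-shot pipeline (the same pipeline used in the RL section below). A minimal sketch; the text, labels, and threshold are illustrative:
from gliclass.pipeline import ZeroShotClassificationPipeline

pipe = ZeroShotClassificationPipeline(model, tokenizer,
                                      classification_type='multi-label',
                                      progress_bar=False, device=device)
text = "A new machine learning platform automates complex data workflows but faces integration issues."
labels = ["AI", "automation", "data_analysis", "usability", "integration"]
results = pipe(text, labels, threshold=0.5)[0]  # results for the first (and only) input
for result in results:
    print(result["label"], "=>", result["score"])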
Full Base Training Script [source]
The following script can be used both for training from scratch and for fine-tuning a pretrained model. It supports data augmentation, LoRA parameter-efficient fine-tuning, and EWC (Elastic Weight Consolidation) for continual learning:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import argparse
import json
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import transformers
from transformers import AutoTokenizer, AutoConfig
from torch.utils.data import WeightedRandomSampler
from packaging import version
from peft import LoraConfig, get_peft_model, TaskType
import random
import torch
from gliclass import GLiClassModelConfig, GLiClassModel
from gliclass.training import TrainingArguments, Trainer
from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset, AugmentationConfig
class CustomTrainer(Trainer):
"""Trainer with weighted random sampling support."""
def __init__(self, *args, use_weighted_sampling=False, **kwargs):
super().__init__(*args, **kwargs)
self.use_weighted_sampling = use_weighted_sampling
def _get_train_sampler(self, train_dataset) -> torch.utils.data.Sampler:
if not self.use_weighted_sampling:
return super()._get_train_sampler()
weights = train_dataset.get_diversity()
return WeightedRandomSampler(
weights=weights,
num_samples=len(train_dataset),
replacement=True
)
def compute_metrics(p, problem_type='multi_label_classification'):
"""Compute evaluation metrics.
Args:
p: Predictions tuple (predictions, labels)
problem_type: Type of classification problem
Returns:
Dictionary of metrics
"""
predictions, labels = p
labels = labels.reshape(-1)
if problem_type == 'single_label_classification':
preds = np.argmax(predictions, axis=1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
elif problem_type == 'multi_label_classification':
predictions = predictions.reshape(-1)
preds = (predictions > 0.5).astype(int)
labels = np.where(labels > 0.5, 1, 0)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
else:
raise NotImplementedError(f"{problem_type} is not implemented.")
def load_dataset(data_path: str) -> list:
"""Load dataset from JSON or JSONL file.
Args:
data_path: Path to JSON or JSONL data file
Returns:
List of data samples
"""
with open(data_path, 'r') as f:
if data_path.endswith('.jsonl'):
data = [json.loads(line) for line in f if line.strip()]
else:
data = json.load(f)
return data
def get_lora_target_modules(model, args):
"""Determine target modules for LoRA based on model architecture.
If --lora_target_modules is provided, use those. Otherwise, auto-detect
all linear layers (excluding the final classification head).
Args:
model: The GLiClassModel
args: Parsed arguments
Returns:
List of module name patterns to target
"""
if args.lora_target_modules:
return args.lora_target_modules
# Collect short names grouped by whether they are Linear or not
linear_names = set()
non_linear_names = set()
for name, module in model.named_modules():
short_name = name.split('.')[-1]
if isinstance(module, torch.nn.Linear):
linear_names.add(short_name)
else:
non_linear_names.add(short_name)
# Only keep names that EXCLUSIVELY refer to Linear layers
# (avoid names like "layer" or "output" that also match container modules)
target_modules = linear_names - non_linear_names
# Remove common classification head names to avoid adapting the output layer
head_names = {'classifier', 'score', 'out_proj', 'dense_out', 'head'}
target_modules -= head_names
target_modules = sorted(target_modules)
print(f"Auto-detected LoRA target modules: {target_modules}")
return target_modules
def apply_lora(model, args):
"""Apply LoRA adapters to the model.
Args:
model: The GLiClassModel
args: Parsed arguments containing LoRA config
Returns:
PeftModel with LoRA adapters applied
"""
target_modules = get_lora_target_modules(model, args)
lora_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=target_modules,
lora_dropout=args.lora_dropout,
bias=args.lora_bias,
modules_to_save=args.lora_modules_to_save,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
return model
def main(args):
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Load or create model
if args.model_name is not None:
model = GLiClassModel.from_pretrained(
args.model_name,
focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma,
focal_loss_reduction=args.focal_loss_reduction
)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
else:
tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)
label_model_config = None
if args.label_model_name is not None:
label_model_config = AutoConfig.from_pretrained(args.label_model_name)
        gliclass_config = GLiClassModelConfig(
encoder_config=encoder_config,
encoder_model=args.encoder_model_name,
label_model_name=args.label_model_name,
label_model_config=label_model_config,
class_token_index=len(tokenizer),
text_token_index=len(tokenizer)+1,
example_token_index=len(tokenizer)+2,
pooling_strategy=args.pooler_type,
class_token_pooling=args.class_token_pooling,
scorer_type=args.scorer_type,
use_lstm=args.use_lstm,
focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma,
focal_loss_reduction=args.focal_loss_reduction,
contrastive_loss_coef=args.contrastive_loss_coef,
normalize_features=args.normalize_features,
extract_text_features=args.extract_text_features,
architecture_type=args.architecture_type,
prompt_first=args.prompt_first,
squeeze_layers=args.squeeze_layers,
layer_wise=args.layer_wise,
encoder_layer_id=args.encoder_layer_id,
shuffle_labels=args.shuffle_labels,
dropout=args.dropout,
use_segment_embeddings=args.use_segment_embeddings,
)
        model = GLiClassModel(gliclass_config, from_pretrained=True).to(dtype=torch.float32)
if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder', 'encoder-decoder-cls'}:
new_words = ["<<LABEL>>", "<<SEP>>", "<<EXAMPLE>>"]
tokenizer.add_tokens(new_words, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
model.to(device)
# Apply LoRA if enabled
if args.use_lora:
print("\n--- Applying LoRA adapters ---")
model = apply_lora(model, args)
print("--- LoRA adapters applied ---\n")
# Get labels tokenizer if needed
if model.config.label_model_name is not None:
labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)
else:
labels_tokenizer = None
model.config.problem_type = args.problem_type
# Load current training data
data = load_dataset(args.data_path)
    data = [item for item in data if len(item['text']) // 4 < 2048]  # rough length filter (~4 characters per token)
print(f'Dataset size: {len(data)}')
random.shuffle(data)
print('Dataset is shuffled...')
train_data = data[:int(len(data) * 0.9)]
test_data = data[int(len(data) * 0.9):]
    print('Dataset is split...')
# Create augmentation config with all parameters
augment_config = AugmentationConfig(
enabled=args.enable_augmentation,
random_label_removal_prob=args.random_label_removal_prob,
random_label_addition_prob=args.random_label_addition_prob,
random_text_addition_prob=args.random_text_addition_prob,
random_add_description_prob=args.random_add_description_prob,
random_add_synonyms_prob=args.random_add_synonyms_prob,
random_add_examples_prob=args.random_add_examples_prob,
max_num_examples=args.max_num_examples
)
if args.labels_desc_path is not None:
labels_descriptions = load_dataset(args.labels_desc_path)
label_to_description = {item.get("label"): item for item in labels_descriptions}
else:
label_to_description = {}
train_dataset = GLiClassDataset(train_data, tokenizer, augment_config,
label_to_description, args.max_length,
args.problem_type, args.architecture_type,
args.prompt_first, labels_tokenizer=labels_tokenizer)
# Disable augmentation for test dataset
test_augment_config = AugmentationConfig(enabled=False)
test_dataset = GLiClassDataset(test_data, tokenizer, test_augment_config,
label_to_description,
args.max_length, args.problem_type,
args.architecture_type, args.prompt_first,
                                   labels_tokenizer=labels_tokenizer)
# Load previous dataset for EWC if provided
prev_dataset = None
if args.use_ewc and args.prev_data_path is not None:
print(f'Loading previous dataset for EWC from: {args.prev_data_path}')
prev_data = load_dataset(args.prev_data_path)
print(f'Previous dataset size: {len(prev_data)}')
# Use a subset if specified
if args.ewc_fisher_samples is not None and args.ewc_fisher_samples < len(prev_data):
random.shuffle(prev_data)
prev_data = prev_data[:args.ewc_fisher_samples]
print(f'Using {len(prev_data)} samples for Fisher estimation')
prev_dataset = GLiClassDataset(prev_data, tokenizer, test_augment_config,
label_to_description,
args.max_length, args.problem_type,
args.architecture_type, args.prompt_first,
                                       labels_tokenizer=labels_tokenizer)
data_collator = DataCollatorWithPadding(device=device)
# Create training arguments with EWC parameters
training_args = TrainingArguments(
output_dir=args.save_path,
learning_rate=args.encoder_lr,
weight_decay=args.encoder_weight_decay,
others_lr=args.others_lr,
others_weight_decay=args.others_weight_decay,
lr_scheduler_type=args.lr_scheduler_type,
warmup_ratio=args.warmup_ratio,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
num_train_epochs=args.num_epochs,
save_steps=args.save_steps,
save_total_limit=args.save_total_limit,
dataloader_num_workers=args.num_workers,
logging_steps=100,
use_cpu=False,
report_to="none",
fp16=args.fp16,
save_only_model=True,
use_ewc=args.use_ewc,
ewc_lambda=args.ewc_lambda,
ewc_fisher_samples=args.ewc_fisher_samples,
ewc_normalize_fisher=args.ewc_normalize_fisher,
ewc_gamma=args.ewc_gamma,
)
# Create compute_metrics function with problem_type closure
def compute_metrics_fn(p):
return compute_metrics(p, args.problem_type)
# Create trainer with EWC support
# Handle version differences between transformers v4 and v5
trainer_kwargs = {
"model": model,
"args": training_args,
"train_dataset": train_dataset,
"eval_dataset": test_dataset,
"data_collator": data_collator,
"compute_metrics": compute_metrics_fn,
"prev_dataset": prev_dataset, # Pass previous dataset for EWC
}
if version.parse(transformers.__version__) < version.parse("5.0.0"):
trainer_kwargs["tokenizer"] = tokenizer
else:
trainer_kwargs["processing_class"] = tokenizer
trainer = CustomTrainer(**trainer_kwargs)
# Print EWC status
if args.use_ewc:
if args.prev_data_path is not None:
print(f'\nEWC enabled with lambda={args.ewc_lambda}')
else:
print('\nWarning: EWC is enabled but no previous data path provided. EWC will not be used.')
trainer.train()
# Save final model
final_output_dir = os.path.join(args.save_path, 'final_model')
if args.use_lora:
print("\n--- Merging LoRA weights into base model ---")
# Merge LoRA adapters into the base model weights
merged_model = model.merge_and_unload()
print("--- LoRA weights merged successfully ---")
# Save the merged model as a standard model (no adapter artifacts)
merged_model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)
print(f'Merged model saved to {final_output_dir}')
# Optionally save the unmerged LoRA adapter separately
if args.save_lora_adapter:
adapter_dir = os.path.join(args.save_path, 'lora_adapter')
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print(f'LoRA adapter also saved separately to {adapter_dir}')
else:
model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)
print(f'Final model saved to {final_output_dir}')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train GLiClass model with optional LoRA and EWC for continual learning')
# Model arguments
parser.add_argument('--model_name', type=str, default=None,
help='Pretrained model name or path')
parser.add_argument('--encoder_model_name', type=str, default='microsoft/deberta-v3-base',
help='Encoder model name')
parser.add_argument('--label_model_name', type=str, default="BAAI/bge-small-en-v1.5",
help='Label model name')
# Path arguments
parser.add_argument('--save_path', type=str, default='models/',
help='Path to save trained model')
parser.add_argument('--data_path', type=str, default='data/zero-cats.json',
help='Path to training data JSON or JSONL file')
parser.add_argument('--prev_data_path', type=str, default=None,
help='Path to previous task data for EWC (required if use_ewc=True)')
parser.add_argument('--labels_desc_path', type=str, default=None)
# Model architecture arguments
parser.add_argument('--problem_type', type=str, default='multi_label_classification',
choices=['single_label_classification', 'multi_label_classification'])
parser.add_argument('--pooler_type', type=str, default='first')
parser.add_argument('--scorer_type', type=str, default='mlp')
parser.add_argument('--architecture_type', type=str, default='uni-encoder')
parser.add_argument('--class_token_pooling', type=str, default='average')
parser.add_argument('--normalize_features', type=bool, default=False)
parser.add_argument('--extract_text_features', type=bool, default=False)
parser.add_argument('--prompt_first', type=bool, default=True)
parser.add_argument('--use_lstm', type=bool, default=False)
parser.add_argument('--squeeze_layers', type=bool, default=False)
parser.add_argument('--layer_wise', type=bool, default=False)
parser.add_argument('--encoder_layer_id', type=int, default=-1)
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--shuffle_labels', type=bool, default=True)
parser.add_argument('--use_segment_embeddings', type=bool, default=True)
# Training arguments
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--gradient_accumulation_steps', type=int, default=2)
parser.add_argument('--encoder_lr', type=float, default=1e-5)
parser.add_argument('--others_lr', type=float, default=1e-5)
parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
parser.add_argument('--others_weight_decay', type=float, default=0.01)
parser.add_argument('--warmup_ratio', type=float, default=0.05)
parser.add_argument('--lr_scheduler_type', type=str, default='cosine')
parser.add_argument('--max_length', type=int, default=4096)
parser.add_argument('--save_steps', type=int, default=1000)
parser.add_argument('--save_total_limit', type=int, default=3)
parser.add_argument('--num_workers', type=int, default=12)
parser.add_argument('--fp16', type=bool, default=False)
# Augmentation parameters
parser.add_argument('--enable_augmentation', type=bool, default=False)
parser.add_argument('--random_label_removal_prob', type=float, default=0.0)
parser.add_argument('--random_label_addition_prob', type=float, default=0.0)
parser.add_argument('--random_text_addition_prob', type=float, default=0.0)
parser.add_argument('--random_add_description_prob', type=float, default=0.1)
parser.add_argument('--random_add_synonyms_prob', type=float, default=0.0)
parser.add_argument('--random_add_examples_prob', type=float, default=0.0)
parser.add_argument('--max_num_examples', type=int, default=5)
# Loss arguments
parser.add_argument('--focal_loss_alpha', type=float, default=0.8)
parser.add_argument('--focal_loss_gamma', type=float, default=-1)
parser.add_argument('--focal_loss_reduction', type=str, default='none',
choices=['none', 'mean', 'sum'])
parser.add_argument('--contrastive_loss_coef', type=float, default=0.)
# LoRA arguments
parser.add_argument('--use_lora', action='store_true',
help='Enable LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning')
parser.add_argument('--lora_r', type=int, default=1024,
help='LoRA rank (dimensionality of low-rank matrices)')
parser.add_argument('--lora_alpha', type=int, default=2048,
help='LoRA alpha (scaling factor, effective scaling = alpha/r)')
parser.add_argument('--lora_dropout', type=float, default=0.05,
help='Dropout probability for LoRA layers')
parser.add_argument('--lora_bias', type=str, default='none',
choices=['none', 'all', 'lora_only'],
help='Which bias parameters to train: none, all, or lora_only')
parser.add_argument('--lora_target_modules', type=str, nargs='+', default=None,
help='List of module names to apply LoRA to (e.g., query_proj value_proj). '
'If not specified, auto-detects linear layers.')
parser.add_argument('--lora_modules_to_save', type=str, nargs='+', default=None,
help='List of module names to fully train (not LoRA-adapted), '
'e.g., classifier heads that should be trained normally')
parser.add_argument('--save_lora_adapter', action='store_true',
help='Save the LoRA adapter separately in addition to the merged model')
# EWC arguments
parser.add_argument('--use_ewc', action='store_true',
help='Enable Elastic Weight Consolidation for continual learning')
parser.add_argument('--ewc_lambda', type=float, default=100.0,
help='Lambda parameter for EWC penalty (higher = more regularization)')
parser.add_argument('--ewc_fisher_samples', type=int, default=None,
help='Number of samples to use for Fisher information estimation (None = use all)')
parser.add_argument('--ewc_normalize_fisher', type=bool, default=True,
help='Whether to normalize Fisher information values')
parser.add_argument('--ewc_gamma', type=float, default=0.95,
help='Decay factor for Online EWC (0 < gamma < 1)')
args = parser.parse_args()
# Validate EWC arguments
if args.use_ewc and args.prev_data_path is None:
print("Warning: --use_ewc is set but --prev_data_path is not provided.")
print("EWC requires previous task data to compute Fisher information.")
print("Training will proceed without EWC.")
main(args)
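Once the script finishes, the final weights (merged, if --use_lora was set) are written to <save_path>/final_model. A minimal loading sketch, assuming the default --save_path of models/:
from gliclass import GLiClassModel
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("models/final_model")
tokenizer = AutoTokenizer.from_pretrained("models/final_model")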
RL Training Script
The GLiClass framework also supports reinforcement learning (RL); you can switch to RL training with just a couple of changes to your training script.
Load pretrained model
This step remains unchanged.
import torch
from gliclass import GLiClassModel
from transformers import AutoTokenizer
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model_name = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
Initialize RL training components
from transformers import AutoModelForSequenceClassification
from gliclass.pipeline import ZeroShotClassificationPipeline
# Value model for advantage estimation
value_model = AutoModelForSequenceClassification.from_pretrained(model.config.encoder_model_name, num_labels=1)
value_model.resize_token_embeddings(len(tokenizer))
# Reference model for baseline comparisons
reference_model = GLiClassModel.from_pretrained(model_name)  # in most cases, the same model can serve as the reference model
reference_tokenizer = AutoTokenizer.from_pretrained(model_name)
reference_pipe = ZeroShotClassificationPipeline(reference_model, reference_tokenizer,
classification_type='multi-label',
progress_bar=False, device=device)
Define Dataset
This step remains unchanged (same dataset format, but it now requires an AugmentationConfig).
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding, AugmentationConfig
data = [
{
"text": "A new machine learning platform automates complex data workflows but faces integration issues.",
"all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
"true_labels": ["AI", "integration", "automation"]
}
]
augment_config = AugmentationConfig(enabled=False)
train_dataset = GLiClassDataset(
data,
tokenizer,
augment_config,
label_to_description={},
max_length=1024,
problem_type='multi_label_classification',
)
# Data collator
data_collator = DataCollatorWithPadding(device=device)
Expected Output
Total labels: 5
Define functions for metrics computation
This step remains unchanged.
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
def compute_metrics_multi_label(p):
predictions, labels = p
labels = labels.reshape(-1)
predictions = predictions.reshape(-1)
preds = (predictions > 0.5).astype(int)
labels = np.where(labels>0.5, 1, 0)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
Define reward function
def default_f1_reward(
probs: torch.Tensor,
actions: torch.Tensor,
original_targets: torch.Tensor,
valid_mask: torch.Tensor
) -> torch.Tensor:
"""
A variant that extracts list-of-indices sets and then calculates
the F1 score in a classical manner. Returns shape (N, 1).
Args:
probs: (N, T) Tensor of probabilities (not used here but left for interface consistency).
actions: (N, T) Tensor of predicted labels in {0, 1}.
original_targets: (N, T) Tensor of ground-truth labels in {0, 1}.
valid_mask: (N, T) Tensor indicating which positions are valid (1) vs. invalid (0).
Returns:
f1_scores: (N, 1) Tensor containing the F1 score for each row.
"""
N = actions.shape[0]
f1_scores = []
for i in range(N):
# Filter valid positions
valid_preds_i = actions[i] * valid_mask[i]
valid_targets_i = original_targets[i] * valid_mask[i]
# Get the set of indices where we predicted 1
predicted_set = set((valid_preds_i == 1).nonzero(as_tuple=True)[0].tolist())
# Get the set of indices where the ground truth is 1
target_set = set((valid_targets_i == 1).nonzero(as_tuple=True)[0].tolist())
# Compute intersection
intersection = predicted_set.intersection(target_set)
# Precision
if len(predicted_set) > 0:
precision = len(intersection) / len(predicted_set)
else:
precision = 0.0
# Recall
if len(target_set) > 0:
recall = len(intersection) / len(target_set)
else:
recall = 0.0
# F1 score
if (precision + recall) > 0:
f1 = 2 * precision * recall / (precision + recall)
else:
f1 = 0.0
f1_scores.append(f1)
# Convert list to tensor shape (N, 1)
f1_scores = torch.tensor(f1_scores, dtype=torch.float).unsqueeze(-1)
return f1_scores.detach().to(probs.device)
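reward_components is a dict, so you can presumably register several rewards; each component must take (probs, actions, original_targets, valid_mask) and return a tensor of shape (N, 1). As a sketch, here is a hypothetical precision reward following the same interface (the full RL script below defines recall and accuracy rewards the same way):
def precision_reward(
    probs: torch.Tensor,
    actions: torch.Tensor,
    original_targets: torch.Tensor,
    valid_mask: torch.Tensor
) -> torch.Tensor:
    # Keep only valid label positions
    valid_preds = actions * valid_mask
    valid_targets = original_targets * valid_mask
    TP = torch.sum(valid_preds * valid_targets, dim=-1)
    FP = torch.sum(valid_preds * (1 - valid_targets), dim=-1)
    eps = 1e-8
    precision = TP / (TP + FP + eps)
    return precision.detach().unsqueeze(1)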
Train the model with RLTrainer
from gliclass.training import RLTrainerConfig, RLTrainer
training_args = RLTrainerConfig(
output_dir="my-awesome-rl-gliclass-model",
learning_rate=1e-5,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type="cosine",
per_device_eval_batch_size=8,
num_train_epochs=1,
logging_steps=100,
    use_cpu=False,
report_to="none",
fp16=False,
cliprange=0.2,
num_rl_iters=2
)
trainer = RLTrainer(
model=model,
value_model=value_model,
reference_model=reference_pipe,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics_multi_label,
reward_components={
'micro_f1': default_f1_reward,
},
)
trainer.train()
To avoid an AttributeError when running in notebooks, add the following lines after initializing the trainer:
trainer = RLTrainer(
model=model,
...
)
from transformers.utils.notebook import NotebookProgressCallback
trainer.remove_callback(NotebookProgressCallback)
trainer.train()
Full RL Training Script [source]
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import argparse
import json
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
import random
import torch
from gliclass import GLiClassModelConfig, GLiClassModel, ZeroShotClassificationPipeline
from gliclass.training import TrainingArguments, Trainer, RLTrainerConfig, RLTrainer
from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset, AugmentationConfig
from gliclass.utils import default_f1_reward
def accuracy_reward(probs, actions, targets, valid_mask):
probs = probs * valid_mask
predicts = torch.argmax(probs, dim=-1)
true_labels = torch.argmax(targets, dim=-1)
correct = (predicts == true_labels).float().unsqueeze(1)
return correct
def recall_reward(
probs: torch.Tensor,
actions: torch.Tensor,
original_targets: torch.Tensor,
valid_mask: torch.Tensor
) -> torch.Tensor:
valid_preds = actions * valid_mask
valid_targets = original_targets * valid_mask
TP = torch.sum((valid_preds * valid_targets), dim=-1)
FN = torch.sum(((1 - valid_preds) * valid_targets), dim=-1)
eps = 1e-8
recall = TP / (TP + FN + eps)
return recall.detach().unsqueeze(1)
def main(args):
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
if args.model_name is not None:
model = GLiClassModel.from_pretrained(args.model_name, focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
else:
tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)
        label_model_config = None
        if args.label_model_name is not None:
            label_model_config = AutoConfig.from_pretrained(args.label_model_name)
        gliclass_config = GLiClassModelConfig(
encoder_config=encoder_config,
encoder_model=args.encoder_model_name,
label_model_name=args.label_model_name,
label_model_config=label_model_config,
class_token_index=len(tokenizer),
text_token_index=len(tokenizer)+1,
pooling_strategy=args.pooler_type,
scorer_type=args.scorer_type,
use_lstm=args.use_lstm,
focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma,
labels_smoothing=args.labels_smoothing,
entropy_beta=args.entropy_beta,
kl_beta=args.kl_beta,
contrastive_loss_coef=args.contrastive_loss_coef,
normalize_features=args.normalize_features,
extract_text_features=args.extract_text_features,
architecture_type=args.architecture_type,
prompt_first=args.prompt_first,
squeeze_layers=args.squeeze_layers
)
        gliclass_config.problem_type = args.problem_type
        model = GLiClassModel(gliclass_config, from_pretrained=True)
if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:
new_words = ["<<LABEL>>", "<<SEP>>"]
tokenizer.add_tokens(new_words, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
if args.set_value_model:
value_model = AutoModelForSequenceClassification.from_pretrained(model.config.encoder_model_name, num_labels=1)
value_model.resize_token_embeddings(len(tokenizer))
else:
value_model = None
if args.reference_model is not None:
        reference_model = GLiClassModel.from_pretrained(args.reference_model)
        reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_model)
        reference_pipe = ZeroShotClassificationPipeline(reference_model, reference_tokenizer,
classification_type='multi-label',
progress_bar=False, device=device)
else:
reference_pipe = None
if args.label_model_name is not None:
labels_tokenizer = AutoTokenizer.from_pretrained(args.label_model_name)
else:
labels_tokenizer = None
model.to(device)
with open(args.data_path, 'r') as f:
        data = json.load(f)
print('Dataset size:', len(data))
random.shuffle(data)
print('Dataset is shuffled...')
train_data = data[:int(len(data)*0.9)]
test_data = data[int(len(data)*0.9):]
    print('Dataset is split...')
augment_config = AugmentationConfig(enabled=False)
train_dataset = GLiClassDataset(train_data, tokenizer, augment_config,
{}, args.max_length,
args.problem_type, args.architecture_type,
args.prompt_first, labels_tokenizer=labels_tokenizer)
test_dataset = GLiClassDataset(test_data, tokenizer, augment_config,
{}, args.max_length, args.problem_type,
args.architecture_type, args.prompt_first,
                                   labels_tokenizer=labels_tokenizer)
data_collator = DataCollatorWithPadding(device=device)
training_args = RLTrainerConfig(
output_dir=args.save_path,
learning_rate=args.encoder_lr,
weight_decay=args.encoder_weight_decay,
others_lr=args.others_lr,
others_weight_decay=args.others_weight_decay,
lr_scheduler_type=args.lr_scheduler_type,
warmup_ratio=args.warmup_ratio,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
num_train_epochs=args.num_epochs,
evaluation_strategy="epoch",
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        dataloader_num_workers=args.num_workers,
        logging_steps=100,
        use_cpu=False,
report_to="none",
fp16=args.fp16,
cliprange=args.clip_range,
num_rl_iters=args.num_rl_iters
)
trainer = RLTrainer(
model=model,
value_model=value_model,
reference_model=reference_pipe,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
reward_components={
'micro_f1': default_f1_reward,
},
)
trainer.train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="knowledgator/gliclass-modern-base-v2.0-init")
    parser.add_argument('--encoder_model_name', type=str, default='microsoft/deberta-v3-small')
    parser.add_argument('--label_model_name', type=str, default="BAAI/bge-small-en-v1.5")
    parser.add_argument('--reference_model', type=str, default=None)
    parser.add_argument('--set_value_model', type=bool, default=True)
    parser.add_argument('--save_path', type=str, default='models/')
    parser.add_argument('--data_path', type=str, default='data/zero-cats.json')
parser.add_argument('--problem_type', type=str, default='multi_label_classification')
parser.add_argument('--pooler_type', type=str, default='avg')
parser.add_argument('--scorer_type', type=str, default='simple')
parser.add_argument('--architecture_type', type=str, default='uni-encoder')
parser.add_argument('--normalize_features', type=bool, default=False)
parser.add_argument('--extract_text_features', type=bool, default=False)
parser.add_argument('--prompt_first', type=bool, default=True)
parser.add_argument('--use_lstm', type=bool, default=False)
parser.add_argument('--squeeze_layers', type=bool, default=False)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--encoder_lr', type=float, default=2e-6)
parser.add_argument('--others_lr', type=float, default=3e-6)
parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
parser.add_argument('--others_weight_decay', type=float, default=0.01)
parser.add_argument('--warmup_ratio', type=float, default=0.05)
parser.add_argument('--lr_scheduler_type', type=str, default='linear')
parser.add_argument('--focal_loss_alpha', type=float, default=-1)
parser.add_argument('--focal_loss_gamma', type=float, default=-1)
parser.add_argument('--labels_smoothing', type=float, default=-1)
parser.add_argument('--entropy_beta', type=float, default=-1)
parser.add_argument('--kl_beta', type=float, default=0.1)
parser.add_argument('--clip_range', type=float, default=0.2)
parser.add_argument('--num_rl_iters', type=int, default=2)
parser.add_argument('--contrastive_loss_coef', type=float, default=0.)
parser.add_argument('--max_length', type=int, default=2048)
parser.add_argument('--save_steps', type=int, default=300)
parser.add_argument('--save_total_limit', type=int, default=3)
parser.add_argument('--num_workers', type=int, default=12)
parser.add_argument('--fp16', type=bool, default=False)
args = parser.parse_args()
main(args)
Evaluation
Once you have trained your model, you will most likely want to evaluate it. We provide a test_gliclass.py [source] script that evaluates a model on 13 different zero-shot classification datasets.
Enter the repository and activate your environment
cd GLiClass
source venv/bin/activate
If you don't have the gliclass framework installed, please check out our installation guide first.
Run evaluation script
python test_gliclass.py --model knowledgator/gliclass-base-v1.0 --api_key YOUR_KEY_IF_REQUIRED