GliZNet: Generalized Zero-Shot Text Classification

Published:

GliZNet: Zero-Shot Multi-Label Text Classification

GliZNet (Generalized Label-Integrated Zero-shot Network) is an innovative architecture for zero-shot text classification that embeds labels directly in the input sequence alongside the text, enabling efficient multi-label classification through a single forward pass. This project advances zero-shot learning by combining supervised contrastive learning with label repulsion techniques.

Project Overview

GliZNet represents a significant advancement in zero-shot text classification by integrating labels directly into the transformer input sequence using special [LAB] tokens. Unlike traditional approaches that require separate forward passes for each label, GliZNet processes text and all candidate labels simultaneously, achieving 10x faster inference while maintaining high accuracy. The model uses ModernBERT-base as its backbone and employs a sophisticated loss function combining supervised contrastive learning, label repulsion, and binary cross-entropy.

Project Objectives

  1. Develop Efficient Zero-Shot Classification:
    • Create a single-pass architecture that processes text and all labels simultaneously
    • Achieve significant speed improvements over traditional cross-encoder approaches
  2. Advance Loss Function Design:
    • Implement supervised contrastive learning for positive label clustering
    • Develop label repulsion mechanisms to prevent embedding collapse
    • Combine multiple loss functions for robust training
  3. Enable Flexible Classification:
    • Support both multi-label and multi-class classification scenarios
    • Provide configurable similarity metrics and projection layers
    • Ensure compatibility with various transformer backbones

Architecture & Methodology

The GliZNet architecture follows a streamlined approach that integrates labels directly into the transformer input:

Input Format

[CLS] text tokens [LAB] label1 [LAB] label2 ... [SEP]

Processing Pipeline

  1. Joint Encoding: Text and labels are encoded together in a single forward pass
  2. Separate Projections: Independent projection layers for text (CLS token) and labels (mean pooling)
  3. Similarity Computation: Cosine similarity with learned temperature scaling
  4. Multi-Objective Training: Combined loss functions for optimal performance

Key Innovations

  • Label Integration: Special [LAB] tokens allow labels to be embedded in the input sequence
  • Supervised Contrastive Loss: Encourages positive labels to cluster together while pushing negative labels apart
  • Label Repulsion: Prevents different labels from collapsing into similar embeddings within the same sample
  • Decoupled BCE: Auxiliary classification loss with independent temperature scaling

Training Configuration

  • Similarity Metric: Cosine similarity with learned temperature (initialized at 2.0)
  • Loss Weights: SupCon (1.0), Label Repulsion (0.1), BCE (1.0)
  • Repulsion Threshold: 0.3 to control label separation
  • Backbone: ModernBERT-base with smart embedding resizing

Tools and Technologies

  • Core Framework:
    • Transformers: HuggingFace transformers library for model implementation
    • PyTorch: Deep learning framework for custom architectures and training
    • ModernBERT: State-of-the-art backbone model for efficient encoding
  • Training & Optimization:
    • Supervised Contrastive Learning: Advanced loss function for representation learning
    • Label Repulsion: Custom loss component to prevent embedding collapse
    • Mixed Precision Training: Efficient training with automatic mixed precision
  • Data & Evaluation:
    • Custom Dataset: ZSHOT-HARDSET for challenging zero-shot evaluation
    • Multi-Label Metrics: Precision, recall, F1-score for comprehensive evaluation
    • Cross-Validation: Rigorous evaluation across multiple domains
  • Deployment & Production:
    • HuggingFace Hub: Model hosting and distribution
    • Gradio Interface: Interactive demo for model exploration
    • FastAPI: Production-ready API endpoints

Project Report

GliZNet successfully demonstrates the effectiveness of label-integrated zero-shot classification, achieving significant performance improvements over baseline approaches. The model was trained on challenging datasets and evaluated across multiple domains, showing robust zero-shot capabilities.

Key achievements include:

  • Single-Pass Efficiency: 10x faster inference compared to cross-encoder baselines
  • Flexible Architecture: Support for both multi-label and multi-class scenarios
  • Advanced Loss Design: Novel combination of contrastive learning and label repulsion
  • Production Ready: Complete pipeline with HuggingFace integration and API endpoints

Outcome and Future Work

  • Research Impact: GliZNet advances the field of zero-shot learning by demonstrating the effectiveness of label integration in transformer architectures. The approach provides a new paradigm for efficient multi-label classification.
  • Practical Applications: The model enables real-world applications in content categorization, intent detection, and automated tagging systems where labeled training data is scarce.
  • Performance Validation: Comprehensive evaluation on benchmark datasets confirms the model’s superiority in zero-shot scenarios, particularly for multi-label tasks.

Future enhancements include:

  • Multi-Lingual Support: Extension to additional languages beyond English and French
  • Domain Adaptation: Fine-tuning capabilities for specific domains
  • Ensemble Methods: Combination with other zero-shot approaches for improved accuracy

Explore GliZNet