**README.md** (7 changes: 4 additions & 3 deletions)
## 🚀 Features

- **Complex input support**: Handle text data alongside categorical variables seamlessly.
- **ValueEncoder**: Pass raw string categorical values and labels directly — no manual integer encoding required. Build a `ValueEncoder` from `DictEncoder` or sklearn `LabelEncoder` instances once, and the wrapper handles encoding at train time and label decoding after prediction automatically.
- **Unified yet highly customizable**:
  - Use any tokenizer from HuggingFace or fastText's original n-gram tokenizer.
  - Compose the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily build custom architectures, including **self-attention**. All of them are `torch.nn.Module` instances!
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
  - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you.
- **Explainability**:
  - **Captum integration**: gradient-based token attribution via integrated gradients (`explain_with_captum=True`).
  - **Label attention**: class-specific cross-attention that produces one sentence embedding per class, enabling token-level explanations for each label (`explain_with_label_attention=True`). Enable it by setting `n_heads_label_attention` in `ModelConfig`.


## 📦 Installation
See the [examples/](examples/) directory for complete examples.
## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.


---

**docs/source/architecture/overview.md** (121 changes: 111 additions & 10 deletions)

At its core, torchTextClassifiers processes data through a simple pipeline:

**Data Flow:**
1. **ValueEncoder** (optional) converts raw string categorical values and labels into integers
2. **Text** is tokenized into numerical tokens
3. **Tokens** are embedded into dense vectors (with optional self-attention)
— or into one embedding *per class* if **label attention** is enabled
4. **Categorical variables** (optional) are embedded separately
5. **All embeddings** are combined
6. **Classification head** produces final predictions
— if a `ValueEncoder` was provided, integer predictions are decoded back to original labels
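The steps above can be traced with toy numpy arrays; the dimensions, mean pooling, and concatenation used here are illustrative assumptions, not the library's internals:

```python
import numpy as np

batch, seq_len, embed_dim, n_cat, cat_dim, num_classes = 2, 8, 16, 1, 4, 3

# 2. Text is tokenized into integer token ids
tokens = np.random.randint(0, 100, size=(batch, seq_len))

# 3. Tokens are embedded into dense vectors, then pooled into one text embedding
token_table = np.random.randn(100, embed_dim)
token_emb = token_table[tokens]        # (batch, seq_len, embed_dim)
text_emb = token_emb.mean(axis=1)      # (batch, embed_dim)

# 4. Categorical variables are embedded separately
cat_table = np.random.randn(5, cat_dim)
cat_values = np.random.randint(0, 5, size=(batch, n_cat))
cat_emb = cat_table[cat_values].reshape(batch, -1)       # (batch, n_cat * cat_dim)

# 5. All embeddings are combined (concatenation is one common choice)
combined = np.concatenate([text_emb, cat_emb], axis=1)   # (batch, embed_dim + n_cat * cat_dim)

# 6. Classification head produces final logits
W = np.random.randn(combined.shape[1], num_classes)
logits = combined @ W                  # (batch, num_classes)
print(logits.shape)                    # (2, 3)
```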

## Component 0: ValueEncoder (optional preprocessing)

**Purpose:** Encode raw string (or mixed-type) categorical values and labels into
integer indices, and decode predicted integers back to original label values after
inference.

### When to Use

Use `ValueEncoder` whenever your categorical features or labels are stored as strings
(e.g. `"Electronics"`, `"positive"`) rather than integers. Without it, you must
integer-encode inputs manually before passing them to `train` / `predict`.

### Building a ValueEncoder

```python
from sklearn.preprocessing import LabelEncoder
from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder

# Option A: sklearn LabelEncoder (fit on train data)
cat_encoder = LabelEncoder().fit(X_train_categories)

# Option B: explicit dict mapping
cat_encoder = DictEncoder({"Electronics": 0, "Audio": 1, "Books": 2})

value_encoder = ValueEncoder(
    label_encoder=LabelEncoder().fit(y_train),  # encodes/decodes labels
    categorical_encoders={
        "category": cat_encoder,   # one entry per categorical column
        # "brand": brand_encoder,  # add more as needed
    },
)
```

### What It Provides

```python
value_encoder.vocabulary_sizes # [3, ...] – inferred from each encoder
value_encoder.num_classes # 2 – inferred from label encoder
```

These are read automatically by `torchTextClassifiers` when constructing the model,
so you don't need to set `num_classes` or `categorical_vocabulary_sizes` in `ModelConfig`
manually.
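Conceptually, each encoder is a bidirectional value-to-index mapping, and the inferred sizes fall out of the mapping lengths. A minimal stand-alone sketch of the idea (plain dicts, not the library's `DictEncoder` implementation):

```python
# Plain-dict stand-ins for the encoders (illustrative only)
category_mapping = {"Electronics": 0, "Audio": 1, "Books": 2}
label_mapping = {"negative": 0, "positive": 1}
inverse_labels = {idx: value for value, idx in label_mapping.items()}

# What the wrapper can infer from the encoders
vocabulary_sizes = [len(category_mapping)]  # [3] -> categorical_vocabulary_sizes
num_classes = len(label_mapping)            # 2  -> num_classes

# Encoding raw values at train time ...
encoded_cats = [category_mapping[v] for v in ["Audio", "Books"]]       # [1, 2]
encoded_labels = [label_mapping[v] for v in ["positive", "negative"]]  # [1, 0]

# ... and decoding integer predictions back after inference
decoded = [inverse_labels[i] for i in [1, 0]]  # ["positive", "negative"]
print(decoded)
```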

### Integration with the Wrapper

```python
classifier = torchTextClassifiers(
    tokenizer=tokenizer,
    model_config=ModelConfig(embedding_dim=64),  # num_classes inferred from encoder
    value_encoder=value_encoder,
)

# Train with raw string inputs (default: raw_categorical_inputs=True, raw_labels=True)
classifier.train(X_train, y_train, training_config)

# Predict — output labels are decoded back to original strings automatically
result = classifier.predict(X_test)
print(result["prediction"]) # ["positive", "negative", ...]
```

The `ValueEncoder` is saved and reloaded with the model via `classifier.save()` /
`torchTextClassifiers.load()`.

---

## Component 1: Tokenizer

- `n_head`: Number of attention heads (typically 4, 8, or 16)
- `n_layer`: Depth of transformer (start with 2-3)

### With Label Attention (Optional Explainability Layer)

Label attention replaces mean-pooling with a **cross-attention mechanism** where each
class has a learnable embedding that attends over the token sequence:

```
Token embeddings (batch, seq_len, d)
        ↓ cross-attention (labels as queries, tokens as keys/values)
Sentence embeddings (batch, num_classes, d)   ← one per class
        ↓
ClassificationHead (d → 1)                    ← shared, applied per class
        ↓
Logits (batch, num_classes)
```

Enable it by setting `n_heads_label_attention` in `ModelConfig`:

```python
model_config = ModelConfig(
    embedding_dim=96,
    num_classes=6,
    n_heads_label_attention=4,  # number of attention heads for label attention
)
```

**Benefits:**
- Free explainability at inference time (`explain_with_label_attention=True` in `predict`)
- The returned attention matrix `(batch, n_head, num_classes, seq_len)` shows which
tokens each class focuses on
- Can be combined with self-attention (`attention_config`)

**Constraint:** `embedding_dim` must be divisible by `n_heads_label_attention`.
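The shapes can be verified with a toy single-head numpy sketch (random weights, no learned projections; an illustration of the mechanism, not the actual module, whose returned attention matrix also carries an `n_head` axis):

```python
import numpy as np

batch, seq_len, d, num_classes = 2, 10, 8, 6

tokens = np.random.randn(batch, seq_len, d)      # token embeddings (keys/values)
label_queries = np.random.randn(num_classes, d)  # one learnable embedding per class

# Cross-attention: labels as queries, tokens as keys/values
scores = np.einsum("cd,bsd->bcs", label_queries, tokens) / np.sqrt(d)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # softmax over tokens
sentence_per_class = np.einsum("bcs,bsd->bcd", weights, tokens)        # (batch, num_classes, d)

# Shared head (d -> 1), applied per class
w_head = np.random.randn(d)
logits = sentence_per_class @ w_head             # (batch, num_classes)

print(weights.shape, logits.shape)  # attention per class over tokens; final logits
```

The `weights` tensor is what makes the explanation free: each row says how much each token contributed to that class's sentence embedding.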

## Component 3: Categorical Variable Handler

**Purpose:** Process categorical features (like user demographics, product categories) alongside text.
## Complete Architecture

```{thumbnail} diagrams/NN.drawio.png
:alt: Complete model architecture diagram
```

### Full Model Assembly

torchTextClassifiers provides a **component-based pipeline** for text classification:

0. **ValueEncoder** (optional) → Encodes raw string inputs; decodes predictions back to original labels
1. **Tokenizer** → Converts text to tokens
2. **Text Embedder** → Creates semantic embeddings (with optional self-attention and/or label attention)
3. **Categorical Handler** (optional) → Processes additional categorical features
4. **Classification Head** → Produces predictions

**Key Benefits:**
- **Examples**: Explore complete examples in the repository

Ready to build your classifier? Start with {doc}`../getting_started/quickstart`!

