Spaces:
Sleeping
Sleeping
Commit ·
c374021
0
Parent(s):
first push
Browse files- .gitattributes +1 -0
- .gitignore +24 -0
- DEPLOYMENT_GUIDE.md +49 -0
- README.md +313 -0
- app.py +876 -0
- config.py +75 -0
- configs/__init__.py +36 -0
- configs/base_config.py +51 -0
- configs/blip_config.py +22 -0
- configs/custom_vlm_config.py +35 -0
- configs/git_config.py +19 -0
- configs/vit_gpt2_config.py +19 -0
- data_prep.py +358 -0
- detailed_technical_report_cross_attention_vlm_image_captioning.md +748 -0
- eval.py +546 -0
- experiments/__init__.py +25 -0
- experiments/ablation_study.py +274 -0
- experiments/cross_attention_patterns.py +243 -0
- experiments/data_prep_analysis.py +281 -0
- experiments/parameter_sweep.py +266 -0
- experiments/results_beam_search_and_decoding_settings_comparison.md +28 -0
- experiments/results_caption_filtering_strategy_comparison.md +43 -0
- experiments/results_cross_attention_masking_impact_on_caption_quality.md +41 -0
- experiments/results_parameter_sweep.md +28 -0
- input.txt +0 -0
- iter_01.ipynb +542 -0
- models/blip_tuner.py +150 -0
- models/custom_vlm.py +563 -0
- models/git_tuner.py +85 -0
- models/vit_gpt2_tuner.py +110 -0
- project_02_DS +1 -0
- requirements.txt +13 -0
- shakespeare_transformer.pt +3 -0
- simplified_overview_vlm_image_captioning_project.md +224 -0
- train.py +472 -0
- transformer2.ipynb +580 -0
- transformer_base.ipynb +446 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Virtual Environments
|
| 7 |
+
venv/
|
| 8 |
+
env/
|
| 9 |
+
.env/
|
| 10 |
+
.venv/
|
| 11 |
+
|
| 12 |
+
# Saved checkpints and generated output
|
| 13 |
+
outputs/
|
| 14 |
+
|
| 15 |
+
# VS Code
|
| 16 |
+
.vscode/
|
| 17 |
+
|
| 18 |
+
# MacOS
|
| 19 |
+
.DS_Store
|
| 20 |
+
|
| 21 |
+
# PyTorch
|
| 22 |
+
*.pth
|
| 23 |
+
|
| 24 |
+
# NOTE: Do NOT ignore shakespeare_transformer.pt, it is required for the Custom VLM
|
DEPLOYMENT_GUIDE.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 How to Deploy VLM Caption Lab to Hugging Face Spaces
|
| 2 |
+
|
| 3 |
+
Since this project requires heavy Machine Learning models (BLIP, ViT-GPT2), the best way to share it with your mentor or reviewers is by deploying it for **free** on **Hugging Face Spaces**. They can use the app instantly in their browser without installing anything.
|
| 4 |
+
|
| 5 |
+
Here are the step-by-step instructions to deploy it right now.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
### Step 1: Create a Hugging Face Space
|
| 10 |
+
1. Go to [huggingface.co/spaces](https://huggingface.co/spaces) and create a free account (or log in).
|
| 11 |
+
2. Click **Create new Space**.
|
| 12 |
+
3. Fill out the form:
|
| 13 |
+
- **Space name**: `vlm-caption-lab` (or whatever you like)
|
| 14 |
+
- **License**: Choose `MIT` or `Creative Commons`
|
| 15 |
+
- **Select the Space SDK**: Click **Streamlit**
|
| 16 |
+
- **Space hardware**: Choose the **Free (CPU basic)** option.
|
| 17 |
+
4. Click **Create Space**.
|
| 18 |
+
|
| 19 |
+
### Step 2: Upload Your Code using the Web UI
|
| 20 |
+
The easiest way is to drag and drop your files.
|
| 21 |
+
1. In your new Space, click on the **Files** tab.
|
| 22 |
+
2. Click **Add file > Upload files**.
|
| 23 |
+
3. Select and upload the following files from your local `project_02` folder:
|
| 24 |
+
- `app.py`
|
| 25 |
+
- `config.py`
|
| 26 |
+
- `data_prep.py`
|
| 27 |
+
- `eval.py`
|
| 28 |
+
- `requirements.txt`
|
| 29 |
+
- `input.txt`
|
| 30 |
+
- `shakespeare_transformer.pt`
|
| 31 |
+
4. Also, recreate the `configs/`, `models/`, and `experiments/` folders in the Hugging Face UI and upload the python files inside them. *(Or, if you know Git, just `git push` your whole repository to the Space!)*
|
| 32 |
+
|
| 33 |
+
### Step 3: Handle the Large `outputs/` Folder (Fine-tuned Weights)
|
| 34 |
+
Your `outputs/` folder is 2.4 GB. You must upload this using **Git LFS** (Large File Storage), or host it as a Hugging Face Dataset and download it on the fly.
|
| 35 |
+
|
| 36 |
+
To keep it simple under a time crunch:
|
| 37 |
+
1. Go to **Settings** in your Space.
|
| 38 |
+
2. Scroll to **Variables and secrets**.
|
| 39 |
+
3. Your app will run using base weights automatically. The mentor will be able to test the *architectures* immediately.
|
| 40 |
+
4. If you absolutely need them to test your *fine-tuned* best weights, simply upload your `outputs/custom_vlm/best/custom_vlm.pt` file manually via the **Files** tab (it's small enough!). You can skip the massive ViT-GPT2 weights.
|
| 41 |
+
|
| 42 |
+
### Step 4: Watch it Build
|
| 43 |
+
Once your files (especially `app.py` and `requirements.txt`) are uploaded, Hugging Face will automatically detect it's a Streamlit app.
|
| 44 |
+
1. Click the **App** tab.
|
| 45 |
+
2. You will see a "Building" log. It will take ~2-3 minutes to install PyTorch and download the model weights into its cache.
|
| 46 |
+
3. Once the status turns green to **Running**, your app is live!
|
| 47 |
+
|
| 48 |
+
### Step 5: Share the Link!
|
| 49 |
+
Just copy the URL from your browser (e.g., `https://huggingface.co/spaces/your-username/vlm-caption-lab`) and send it to your mentor. You're done!
|
README.md
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🔬 VLM Caption Lab
|
| 2 |
+
|
| 3 |
+
**Compare how different Vision-Language Models look at images while writing captions — four architectures, one dataset, one evaluation metric.**
|
| 4 |
+
|
| 5 |
+
VLM Caption Lab is a complete Python toolkit for training, evaluating, and interactively comparing four fundamentally different approaches to **image captioning** (the task of generating a text description of a photograph). It includes a unified training pipeline, quality evaluation using CIDEr scores, three reproducible experiments, and an interactive Streamlit web demo.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Architecture Comparison
|
| 10 |
+
|
| 11 |
+
| Architecture | How It Looks at the Image | Total Parameters | Best CIDEr Score |
|
| 12 |
+
|---|---|---|---|
|
| 13 |
+
| **BLIP** | Selective gated attention — looks at image only when needed | 224M | **0.6199** (optimized) |
|
| 14 |
+
| **ViT-GPT2** | Full attention — looks at entire image for every word | 239M | ~0.55 |
|
| 15 |
+
| **GIT** | Memory-based — memorizes image first, writes from memory | 177M | ~0.54 |
|
| 16 |
+
| **Custom VLM** | Built from scratch — Shakespeare decoder + visual bridge | 103M (16.2M trainable) | **0.2863** |
|
| 17 |
+
|
| 18 |
+
> **What is CIDEr?** CIDEr (Consensus-based Image Description Evaluation) compares the model's caption to five human-written descriptions of the same image. Higher = better. A score of 1.0 means perfect overlap with human references.
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## 🌐 Live Demo & Deployment
|
| 23 |
+
|
| 24 |
+
**The easiest way to test this project is via the live web demo.**
|
| 25 |
+
> 👉 **[Insert Your Live Hosted Link Here]**
|
| 26 |
+
|
| 27 |
+
*(If deploying yourself, see the `DEPLOYMENT_GUIDE.md` file for instructions on hosting this securely and for free on Hugging Face Spaces).*
|
| 28 |
+
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
## Quick Start (Local Run)
|
| 32 |
+
|
| 33 |
+
If you prefer to run this locally rather than using the web demo, follow these steps.
|
| 34 |
+
|
| 35 |
+
> ⚠️ **Note on Weights**: You do *not* need to train the models yourself to test the app.
|
| 36 |
+
> - Base model weights (BLIP, ViT-GPT2) will download automatically from Hugging Face on the first run.
|
| 37 |
+
> - The Custom VLM text-decoder weights (`shakespeare_transformer.pt`) are included in this repo.
|
| 38 |
+
> - **To skip training completely**, you only need to run `streamlit run app.py`!
|
| 39 |
+
|
| 40 |
+
### Prerequisites
|
| 41 |
+
|
| 42 |
+
- Python 3.9 or newer
|
| 43 |
+
- macOS with Apple Silicon (MPS) or Linux with a CUDA GPU
|
| 44 |
+
- ~8 GB disk space for model checkpoints
|
| 45 |
+
|
| 46 |
+
### Setup
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
# Clone the repository
|
| 50 |
+
git clone <repo-url>
|
| 51 |
+
cd project_02
|
| 52 |
+
|
| 53 |
+
# Create a virtual environment
|
| 54 |
+
python -m venv venv
|
| 55 |
+
source venv/bin/activate
|
| 56 |
+
|
| 57 |
+
# Install all dependencies
|
| 58 |
+
pip install -r requirements.txt
|
| 59 |
+
|
| 60 |
+
# Verify that GPU acceleration is available
|
| 61 |
+
python -c "import torch; print('MPS:', torch.backends.mps.is_available()); print('CUDA:', torch.cuda.is_available())"
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### Dependencies
|
| 65 |
+
|
| 66 |
+
| Package | What It Does |
|
| 67 |
+
|---|---|
|
| 68 |
+
| `torch` | Deep learning framework (training and inference) |
|
| 69 |
+
| `transformers` | Load pre-trained BLIP, ViT-GPT2, and GIT models from HuggingFace |
|
| 70 |
+
| `datasets` | Download and load MS-COCO caption dataset from HuggingFace |
|
| 71 |
+
| `streamlit` | Interactive web demo interface |
|
| 72 |
+
| `pycocoevalcap` | Compute CIDEr scores (caption quality metric) |
|
| 73 |
+
| `detoxify` | Safety filter — checks captions for toxic or offensive content |
|
| 74 |
+
| `Pillow` | Image loading and processing |
|
| 75 |
+
| `accelerate` | Training efficiency utilities |
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## 🚀 What to Expect on First Run
|
| 80 |
+
|
| 81 |
+
When someone clones this repository and runs `streamlit run app.py` (or `train.py`) for the very first time, here is exactly what happens:
|
| 82 |
+
|
| 83 |
+
1. **Automatic Model Downloads**: You do *not* need to manually download any heavy weights for BLIP, ViT-GPT2, or GIT. The `transformers` library will automatically download the base weights from HuggingFace the first time you select them.
|
| 84 |
+
2. **Download Time**: This initial download may take a few minutes depending on your internet connection (BLIP is ~900MB, ViT-GPT2 is ~1GB). It will be cached locally on your machine for all future runs, so subsequent loads will be nearly instant.
|
| 85 |
+
3. **Custom VLM Weights**: The `shakespeare_transformer.pt` file (~71MB) included in this repository contains the pre-trained text decoder for the Custom VLM. By including it in the repo, the Custom VLM is ready to generate Shakespearean text immediately without any downloading.
|
| 86 |
+
4. **Fine-Tuned Weights**: To use the "Fine-tuned (Best)" or "Fine-tuned (Latest)" options in the web app, you must first run the training scripts (`python train.py --model [name]`). The training scripts will automatically create an `outputs/` directory and save your fine-tuned weights there.
|
| 87 |
+
|
| 88 |
+
---
|
| 89 |
+
|
| 90 |
+
## Training
|
| 91 |
+
|
| 92 |
+
All four models are trained through one unified script:
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
# Train individual models
|
| 96 |
+
python train.py --model blip # ~1.5 hours on Apple Silicon
|
| 97 |
+
python train.py --model vit_gpt2 # ~1 hour
|
| 98 |
+
python train.py --model git # ~20 minutes
|
| 99 |
+
python train.py --model custom # ~3 hours (15 epochs)
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### What happens during training
|
| 103 |
+
|
| 104 |
+
1. **Dataset loading** — Downloads MS-COCO captions from HuggingFace (cached after first download)
|
| 105 |
+
2. **Training** — Images are processed by the vision encoder, captions by the text decoder
|
| 106 |
+
3. **Validation** — After each epoch, computes validation loss + CIDEr score on held-out images
|
| 107 |
+
4. **Checkpointing** — Saves two checkpoints:
|
| 108 |
+
- `outputs/{model}/best/` — The model with the **highest CIDEr score** (use this for evaluation)
|
| 109 |
+
- `outputs/{model}/latest/` — The most recent epoch (use for debugging or continuing training)
|
| 110 |
+
|
| 111 |
+
### Key hyperparameters
|
| 112 |
+
|
| 113 |
+
| | BLIP | ViT-GPT2 | GIT | Custom VLM |
|
| 114 |
+
|-|---|---|---|---|
|
| 115 |
+
| Training epochs | 3 | 3 | 3 | 15 |
|
| 116 |
+
| Learning rate | 1e-5 | 2e-5 | 2e-5 | 1e-4 / 5e-5 |
|
| 117 |
+
| Batch size | 16 | 8 | 8 | 16 |
|
| 118 |
+
| Effective batch size | 64 | 32 | 32 | 64 |
|
| 119 |
+
| Training images | 30,000 | 15,000 | 15,000 | 15,000 |
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
## Evaluation
|
| 124 |
+
|
| 125 |
+
### Basic evaluation
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
# Evaluate a single model (computes CIDEr score)
|
| 129 |
+
python eval.py --model blip --weights best
|
| 130 |
+
|
| 131 |
+
# Evaluate with pre-trained weights (no fine-tuning)
|
| 132 |
+
python eval.py --model blip --weights base
|
| 133 |
+
|
| 134 |
+
# Compare all models side by side
|
| 135 |
+
python eval.py --model all --weights best
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
### Experiments
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
# Cross-attention masking experiment: what happens when we hide parts of the image?
|
| 142 |
+
python eval.py --model blip --ablation --weights best
|
| 143 |
+
|
| 144 |
+
# Decoding parameter sweep: find the best beam search settings
|
| 145 |
+
python eval.py --model blip --sweep --weights best
|
| 146 |
+
|
| 147 |
+
# Caption filtering analysis: does training data quality matter?
|
| 148 |
+
python eval.py --model blip --data-prep-analysis --weights best
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
### Custom decoding settings
|
| 152 |
+
|
| 153 |
+
```bash
|
| 154 |
+
python eval.py --model blip --weights best \
|
| 155 |
+
--num_beams 10 \
|
| 156 |
+
--max_new_tokens 50 \
|
| 157 |
+
--length_penalty 1.2
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
### All command-line options
|
| 161 |
+
|
| 162 |
+
| Flag | Values | Default | What It Controls |
|
| 163 |
+
|---|---|---|---|
|
| 164 |
+
| `--model` | blip, vit_gpt2, git, custom, all | blip | Which model(s) to evaluate |
|
| 165 |
+
| `--weights` | base, finetuned, best | base | Which checkpoint to load |
|
| 166 |
+
| `--eval_batches` | any integer | 25 | How many validation batches to evaluate |
|
| 167 |
+
| `--num_beams` | 1–10+ | 10 | Beam search width (more = better but slower) |
|
| 168 |
+
| `--max_new_tokens` | 10–100 | 50 | Maximum caption length |
|
| 169 |
+
| `--length_penalty` | 0.5–2.0 | 1.2 | < 1.0 = longer captions, > 1.0 = shorter |
|
| 170 |
+
| `--ablation` | flag | off | Run the cross-attention masking experiment |
|
| 171 |
+
| `--sweep` | flag | off | Run the decoding parameter sweep |
|
| 172 |
+
| `--data-prep-analysis` | flag | off | Run the caption filtering comparison |
|
| 173 |
+
|
| 174 |
+
---
|
| 175 |
+
|
| 176 |
+
## Streamlit Demo
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
streamlit run app.py
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
The demo provides three tabs:
|
| 183 |
+
|
| 184 |
+
### 🖼️ Caption Tab
|
| 185 |
+
Upload any image and generate a caption. Choose which model to use, which checkpoint (pre-trained or fine-tuned), and which generation mode.
|
| 186 |
+
|
| 187 |
+
### 📊 Compare All Models Tab
|
| 188 |
+
Run all four architectures simultaneously on the same image. Results appear in a side-by-side grid with a summary table showing each model's approach and caption.
|
| 189 |
+
|
| 190 |
+
### 📈 Experiment Results Tab
|
| 191 |
+
Browse pre-computed results from all three experiments.
|
| 192 |
+
|
| 193 |
+
### Sidebar Controls
|
| 194 |
+
- **Weight Source** — Switch between pre-trained models and your fine-tuned checkpoints
|
| 195 |
+
- **Architecture** — Select any of the four models (each has an info card explaining its approach)
|
| 196 |
+
- **Generation Mode** — Choose masking modes for BLIP/ViT-GPT2 or Shakespeare Prefix for Custom VLM
|
| 197 |
+
- **Advanced Controls** — Adjust beam width, temperature, length penalty, top-k, and top-p
|
| 198 |
+
|
| 199 |
+
> **Safety:** All captions pass through a toxicity filter (`detoxify`) before being displayed.
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## Configuration
|
| 204 |
+
|
| 205 |
+
Hyperparameters are managed through Python dataclasses in `configs/`:
|
| 206 |
+
|
| 207 |
+
```
|
| 208 |
+
configs/
|
| 209 |
+
├── base_config.py # Shared defaults (batch size, image size, optimizer settings)
|
| 210 |
+
├── blip_config.py # BLIP-specific overrides
|
| 211 |
+
├── vit_gpt2_config.py # ViT-GPT2-specific overrides
|
| 212 |
+
├── git_config.py # GIT-specific overrides
|
| 213 |
+
└── custom_vlm_config.py # Custom VLM overrides (decoder architecture, learning rates)
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
Access any config in code:
|
| 217 |
+
|
| 218 |
+
```python
|
| 219 |
+
from configs import get_config
|
| 220 |
+
cfg = get_config("blip") # Returns BlipConfig instance with all settings
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
## Experiments & Key Results
|
| 226 |
+
|
| 227 |
+
### 1. Cross-Attention Masking: What Happens When We Hide Image Patches?
|
| 228 |
+
|
| 229 |
+
| What We Did | CIDEr Score | Change |
|
| 230 |
+
|---|---|---|
|
| 231 |
+
| Showed the full image | 0.5371 | — Baseline |
|
| 232 |
+
| Hid 50% of image patches randomly | 0.5371 | **No change** |
|
| 233 |
+
| Showed only the center of the image | 0.5371 | **No change** |
|
| 234 |
+
| Compressed entire image to 1 token | 0.0008 | **−99.8%** |
|
| 235 |
+
|
| 236 |
+
**Takeaway:** Half the image patches are redundant, but spatial structure is essential.
|
| 237 |
+
|
| 238 |
+
### 2. Beam Search Settings: What Produces the Best Captions?
|
| 239 |
+
|
| 240 |
+
**Best configuration found:** beam_size=10, length_penalty=1.2, max_tokens=50 → **CIDEr: 0.6199**
|
| 241 |
+
|
| 242 |
+
More beams and slight preference for conciseness improve caption quality by ~13%.
|
| 243 |
+
|
| 244 |
+
### 3. Caption Filtering: Does Training Data Quality Matter?
|
| 245 |
+
|
| 246 |
+
| Strategy | CIDEr |
|
| 247 |
+
|---|---|
|
| 248 |
+
| Raw (no filtering) | **0.6359** |
|
| 249 |
+
| Filtered (5–25 words) | 0.5877 |
|
| 250 |
+
|
| 251 |
+
Raw works best for this already-clean dataset. Filtering recommended for noisier data.
|
| 252 |
+
|
| 253 |
+
---
|
| 254 |
+
|
| 255 |
+
## Project Structure
|
| 256 |
+
|
| 257 |
+
```
|
| 258 |
+
project_02/
|
| 259 |
+
├── app.py # Streamlit web demo (3 tabs)
|
| 260 |
+
├── config.py # Backward-compatible config wrapper
|
| 261 |
+
├── data_prep.py # Dataset loading + caption filtering
|
| 262 |
+
├── eval.py # CIDEr evaluator + experiment runner
|
| 263 |
+
├── train.py # Unified training loop for all 4 models
|
| 264 |
+
├── requirements.txt # Python dependencies
|
| 265 |
+
├── input.txt # Shakespeare corpus (vocabulary source)
|
| 266 |
+
├── shakespeare_transformer.pt # Pre-trained Shakespeare decoder weights
|
| 267 |
+
│
|
| 268 |
+
├── configs/ # Hyperparameter configs
|
| 269 |
+
│ ├── base_config.py # Shared defaults
|
| 270 |
+
│ ├── blip_config.py # BLIP settings
|
| 271 |
+
│ ├── vit_gpt2_config.py # ViT-GPT2 settings
|
| 272 |
+
│ ├── git_config.py # GIT settings
|
| 273 |
+
│ └── custom_vlm_config.py # Custom VLM settings
|
| 274 |
+
│
|
| 275 |
+
├── models/ # Model implementations
|
| 276 |
+
│ ├── blip_tuner.py # BLIP (gated cross-attention)
|
| 277 |
+
│ ├── vit_gpt2_tuner.py # ViT-GPT2 (full cross-attention)
|
| 278 |
+
│ ├── git_tuner.py # GIT (no cross-attention)
|
| 279 |
+
│ └── custom_vlm.py # Custom VLM (visual prefix-tuning)
|
| 280 |
+
│
|
| 281 |
+
├── experiments/ # Experiment scripts and results
|
| 282 |
+
│ ├── ablation_study.py # Image masking experiment
|
| 283 |
+
│ ├── parameter_sweep.py # Beam search settings sweep
|
| 284 |
+
│ ├── data_prep_analysis.py # Caption filtering comparison
|
| 285 |
+
│ └── cross_attention_patterns.py # Architecture comparison table
|
| 286 |
+
│
|
| 287 |
+
├── outputs/ # Saved model checkpoints
|
| 288 |
+
│ ├── blip/{best,latest}/
|
| 289 |
+
│ └── custom_vlm/{best,latest}/
|
| 290 |
+
│
|
| 291 |
+
├── detailed_technical_report_cross_attention_vlm_image_captioning.md
|
| 292 |
+
├── simplified_overview_vlm_image_captioning_project.md
|
| 293 |
+
└── README.md # This file
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
---
|
| 297 |
+
|
| 298 |
+
## Tech Stack
|
| 299 |
+
|
| 300 |
+
| Component | Technology |
|
| 301 |
+
|---|---|
|
| 302 |
+
| Training Framework | PyTorch + HuggingFace Transformers |
|
| 303 |
+
| Dataset | MS-COCO Captions (via HuggingFace Datasets) |
|
| 304 |
+
| Evaluation Metric | CIDEr (via pycocoevalcap) |
|
| 305 |
+
| Safety Filter | detoxify (toxicity detection) |
|
| 306 |
+
| Web Demo | Streamlit |
|
| 307 |
+
| Hardware | Apple Silicon Mac with MPS acceleration |
|
| 308 |
+
|
| 309 |
+
---
|
| 310 |
+
|
| 311 |
+
## Author
|
| 312 |
+
|
| 313 |
+
**Manoj Kumar** — March 2026
|
app.py
ADDED
|
@@ -0,0 +1,876 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py
|
| 3 |
+
======
|
| 4 |
+
VLM Caption Lab — Premium Streamlit Demo
|
| 5 |
+
|
| 6 |
+
Features:
|
| 7 |
+
• Sidebar — Weight Source: Base / Fine-tuned (Best) / Fine-tuned (Latest)
|
| 8 |
+
• Sidebar — Architecture selector, Generation Mode, Advanced Controls
|
| 9 |
+
• Tab 1 — Caption: Single model captioning with weight selection
|
| 10 |
+
• Tab 2 — Compare: Side-by-side 4-model comparison (same image, same config)
|
| 11 |
+
• Tab 3 — Results: Pre-computed benchmark comparison tables
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import warnings
|
| 16 |
+
import torch
|
| 17 |
+
import streamlit as st
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from models.blip_tuner import generate_with_mask
|
| 20 |
+
|
| 21 |
+
warnings.filterwarnings("ignore", message="urllib3 v2 only supports OpenSSL")
|
| 22 |
+
warnings.filterwarnings("ignore", category=UserWarning, message=".*use_fast.*")
|
| 23 |
+
|
| 24 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 25 |
+
# Page Config & CSS
|
| 26 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
|
| 28 |
+
st.set_page_config(
|
| 29 |
+
page_title="VLM Caption Lab",
|
| 30 |
+
page_icon="🔬",
|
| 31 |
+
layout="wide",
|
| 32 |
+
initial_sidebar_state="expanded",
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
st.markdown("""
|
| 36 |
+
<style>
|
| 37 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
|
| 38 |
+
html, body, [class*="css"] {
|
| 39 |
+
font-family: 'Inter', sans-serif;
|
| 40 |
+
background-color: #0d1117;
|
| 41 |
+
color: #e6edf3;
|
| 42 |
+
}
|
| 43 |
+
section[data-testid="stSidebar"] {
|
| 44 |
+
background: linear-gradient(180deg, #161b22 0%, #0d1117 100%);
|
| 45 |
+
border-right: 1px solid #30363d;
|
| 46 |
+
}
|
| 47 |
+
section[data-testid="stSidebar"] .block-container { padding-top: 2rem; }
|
| 48 |
+
.main .block-container { padding-top: 1.5rem; max-width: 1200px; }
|
| 49 |
+
.hero-title {
|
| 50 |
+
background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 50%, #ff7b72 100%);
|
| 51 |
+
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
|
| 52 |
+
font-size: 2.4rem; font-weight: 700; letter-spacing: -0.5px; margin-bottom: 0.2rem;
|
| 53 |
+
}
|
| 54 |
+
.hero-sub { color: #8b949e; font-size: 0.98rem; margin-bottom: 1.5rem; }
|
| 55 |
+
.result-card {
|
| 56 |
+
background: linear-gradient(135deg, #161b22, #1c2128);
|
| 57 |
+
border: 1px solid #30363d; border-radius: 12px;
|
| 58 |
+
padding: 1.5rem; margin-top: 0.8rem;
|
| 59 |
+
}
|
| 60 |
+
.compare-card {
|
| 61 |
+
background: linear-gradient(135deg, #161b22, #1c2128);
|
| 62 |
+
border: 1px solid #30363d; border-radius: 12px;
|
| 63 |
+
padding: 1.2rem; margin-top: 0.5rem; min-height: 160px;
|
| 64 |
+
}
|
| 65 |
+
.caption-text { font-size: 1.15rem; font-weight: 600; color: #e6edf3; line-height: 1.5; }
|
| 66 |
+
.compare-caption { font-size: 1.0rem; font-weight: 500; color: #e6edf3; line-height: 1.4; }
|
| 67 |
+
.badge { display: inline-block; padding: 3px 10px; border-radius: 20px;
|
| 68 |
+
font-size: 0.78rem; font-weight: 600; margin-right: 6px; }
|
| 69 |
+
.badge-blue { background: rgba(88,166,255,0.15); color:#58a6ff; border:1px solid #388bfd; }
|
| 70 |
+
.badge-purple { background: rgba(188,140,255,0.15); color:#bc8cff; border:1px solid #9a6eff; }
|
| 71 |
+
.badge-green { background: rgba(63,185,80,0.15); color:#3fb950; border:1px solid #2ea043; }
|
| 72 |
+
.badge-red { background: rgba(248,81,73,0.15); color:#f85149; border:1px solid #da3633; }
|
| 73 |
+
.badge-orange { background: rgba(210,153,34,0.15); color:#d2993a; border:1px solid #bb8009; }
|
| 74 |
+
.badge-yellow { background: rgba(210,153,34,0.15); color:#e3b341; border:1px solid #bb8009; }
|
| 75 |
+
.weight-tag { display: inline-block; padding: 2px 8px; border-radius: 12px;
|
| 76 |
+
font-size: 0.72rem; font-weight: 500; margin-left: 4px; }
|
| 77 |
+
.wt-base { background: rgba(88,166,255,0.1); color:#58a6ff; border:1px solid #1f6feb; }
|
| 78 |
+
.wt-best { background: rgba(63,185,80,0.1); color:#3fb950; border:1px solid #2ea043; }
|
| 79 |
+
.wt-latest { background: rgba(210,153,34,0.1); color:#d2993a; border:1px solid #bb8009; }
|
| 80 |
+
.arch-box {
|
| 81 |
+
background: #161b22; border-left: 3px solid #58a6ff;
|
| 82 |
+
border-radius: 0 8px 8px 0; padding: 0.8rem 1.2rem;
|
| 83 |
+
margin-top: 0.8rem; font-size: 0.85rem; color: #8b949e; line-height: 1.6;
|
| 84 |
+
}
|
| 85 |
+
.config-banner {
|
| 86 |
+
background: #161b22; border: 1px solid #21262d; border-radius: 8px;
|
| 87 |
+
padding: 0.7rem 1rem; margin-bottom: 0.8rem; font-size: 0.82rem; color: #8b949e;
|
| 88 |
+
}
|
| 89 |
+
.stButton > button {
|
| 90 |
+
background: linear-gradient(135deg, #388bfd, #9a6eff);
|
| 91 |
+
color: white; border: none; border-radius: 8px;
|
| 92 |
+
padding: 0.6rem 1.8rem; font-weight: 600; font-size: 1rem;
|
| 93 |
+
transition: opacity 0.2s;
|
| 94 |
+
}
|
| 95 |
+
.stButton > button:hover { opacity: 0.85; }
|
| 96 |
+
div[data-testid="stSelectbox"] label,
|
| 97 |
+
div[data-testid="stFileUploader"] label { color: #c9d1d9 !important; font-weight: 500; }
|
| 98 |
+
.stAlert { border-radius: 8px; }
|
| 99 |
+
.stTabs [data-baseweb="tab"] { font-weight: 600; }
|
| 100 |
+
.param-section {
|
| 101 |
+
background: #161b22; border: 1px solid #21262d;
|
| 102 |
+
border-radius: 8px; padding: 1rem; margin-top: 0.5rem;
|
| 103 |
+
}
|
| 104 |
+
</style>
|
| 105 |
+
""", unsafe_allow_html=True)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 109 |
+
# Architecture Info & Constants
|
| 110 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 111 |
+
|
| 112 |
+
ARCH_INFO = {
|
| 113 |
+
"BLIP (Multimodal Mixture Attention)": (
|
| 114 |
+
"🔵 <b>BLIP</b> uses a Mixture-of-Encoder-Decoder (MED) architecture. "
|
| 115 |
+
"Gated cross-attention is injected between self-attention and FFN layers."
|
| 116 |
+
),
|
| 117 |
+
"ViT-GPT2 (Standard Cross-Attention)": (
|
| 118 |
+
"🟣 <b>ViT-GPT2</b>: every GPT-2 text token attends to <em>all</em> "
|
| 119 |
+
"197 ViT patch embeddings via full cross-attention at every decoder layer."
|
| 120 |
+
),
|
| 121 |
+
"GIT (Zero Cross-Attention)": (
|
| 122 |
+
"🟠 <b>GIT</b> abandons cross-attention entirely. Image patches are "
|
| 123 |
+
"concatenated to the front of the token sequence; no cross-attention block."
|
| 124 |
+
),
|
| 125 |
+
"Custom VLM (Shakespeare Prefix)": (
|
| 126 |
+
"🟢 <b>Custom VLM</b> fuses a frozen ViT with a Shakespeare char-level "
|
| 127 |
+
"decoder via a single trainable Linear(768→384) projection."
|
| 128 |
+
),
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
MODEL_KEYS = [
|
| 132 |
+
"BLIP (Multimodal Mixture Attention)",
|
| 133 |
+
"ViT-GPT2 (Standard Cross-Attention)",
|
| 134 |
+
"GIT (Zero Cross-Attention)",
|
| 135 |
+
"Custom VLM (Shakespeare Prefix)",
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
MODEL_SHORT = {
|
| 139 |
+
"BLIP (Multimodal Mixture Attention)": "BLIP",
|
| 140 |
+
"ViT-GPT2 (Standard Cross-Attention)": "ViT-GPT2",
|
| 141 |
+
"GIT (Zero Cross-Attention)": "GIT",
|
| 142 |
+
"Custom VLM (Shakespeare Prefix)": "Custom VLM",
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
MODEL_BADGE = {
|
| 146 |
+
"BLIP (Multimodal Mixture Attention)": "badge-blue",
|
| 147 |
+
"ViT-GPT2 (Standard Cross-Attention)": "badge-purple",
|
| 148 |
+
"GIT (Zero Cross-Attention)": "badge-orange",
|
| 149 |
+
"Custom VLM (Shakespeare Prefix)": "badge-green",
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
MODEL_CA_TYPE = {
|
| 153 |
+
"BLIP (Multimodal Mixture Attention)": "Gated MED Cross-Attention",
|
| 154 |
+
"ViT-GPT2 (Standard Cross-Attention)": "Full Cross-Attention",
|
| 155 |
+
"GIT (Zero Cross-Attention)": "Self-Attention Prefix",
|
| 156 |
+
"Custom VLM (Shakespeare Prefix)": "Linear Bridge Prefix",
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
WEIGHT_TAG_CLASS = {"base": "wt-base", "best": "wt-best", "latest": "wt-latest"}
|
| 160 |
+
WEIGHT_LABEL = {"base": "Base", "best": "Best", "latest": "Latest"}
|
| 161 |
+
|
| 162 |
+
OUTPUT_ROOT = "./outputs"
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 166 |
+
# Device
|
| 167 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 168 |
+
|
| 169 |
+
def get_device():
|
| 170 |
+
if torch.backends.mps.is_available(): return torch.device("mps")
|
| 171 |
+
if torch.cuda.is_available(): return torch.device("cuda")
|
| 172 |
+
return torch.device("cpu")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 176 |
+
# Weight Loading Helpers
|
| 177 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 178 |
+
|
| 179 |
+
def _has_finetuned(model_dir, subdir):
|
| 180 |
+
"""Check if a fine-tuned checkpoint exists for a given model + subdir."""
|
| 181 |
+
path = os.path.join(OUTPUT_ROOT, model_dir, subdir)
|
| 182 |
+
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _ckpt_path(model_dir, subdir):
|
| 186 |
+
return os.path.join(OUTPUT_ROOT, model_dir, subdir)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 190 |
+
# Cached Model Loaders (with weight_source support)
|
| 191 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 192 |
+
|
| 193 |
+
@st.cache_resource(show_spinner=False)
|
| 194 |
+
def load_blip(weight_source="base"):
|
| 195 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 196 |
+
device = get_device()
|
| 197 |
+
processor = BlipProcessor.from_pretrained(
|
| 198 |
+
"Salesforce/blip-image-captioning-base", use_fast=True)
|
| 199 |
+
model = BlipForConditionalGeneration.from_pretrained(
|
| 200 |
+
"Salesforce/blip-image-captioning-base")
|
| 201 |
+
|
| 202 |
+
if weight_source != "base":
|
| 203 |
+
ckpt = _ckpt_path("blip", weight_source)
|
| 204 |
+
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
| 205 |
+
try:
|
| 206 |
+
loaded = BlipForConditionalGeneration.from_pretrained(ckpt)
|
| 207 |
+
model.load_state_dict(loaded.state_dict())
|
| 208 |
+
del loaded
|
| 209 |
+
except Exception as e:
|
| 210 |
+
print(f"⚠️ Could not load BLIP {weight_source} weights: {e}")
|
| 211 |
+
|
| 212 |
+
model.to(device).eval()
|
| 213 |
+
return processor, model, device
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@st.cache_resource(show_spinner=False)
|
| 217 |
+
def load_vit_gpt2(weight_source="base"):
|
| 218 |
+
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
|
| 219 |
+
device = get_device()
|
| 220 |
+
model_id = "nlpconnect/vit-gpt2-image-captioning"
|
| 221 |
+
processor = ViTImageProcessor.from_pretrained(model_id, use_fast=True)
|
| 222 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 223 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 224 |
+
model = VisionEncoderDecoderModel.from_pretrained(model_id)
|
| 225 |
+
model.config.decoder_start_token_id = tokenizer.bos_token_id
|
| 226 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 227 |
+
|
| 228 |
+
if weight_source != "base":
|
| 229 |
+
ckpt = _ckpt_path("vit_gpt2", weight_source)
|
| 230 |
+
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
| 231 |
+
try:
|
| 232 |
+
loaded = VisionEncoderDecoderModel.from_pretrained(ckpt)
|
| 233 |
+
model.load_state_dict(loaded.state_dict())
|
| 234 |
+
del loaded
|
| 235 |
+
except Exception as e:
|
| 236 |
+
print(f"⚠️ Could not load ViT-GPT2 {weight_source} weights: {e}")
|
| 237 |
+
|
| 238 |
+
model.to(device).eval()
|
| 239 |
+
return processor, tokenizer, model, device
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@st.cache_resource(show_spinner=False)
|
| 243 |
+
def load_git(weight_source="base"):
|
| 244 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 245 |
+
device = get_device()
|
| 246 |
+
model_id = "microsoft/git-base-coco"
|
| 247 |
+
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
|
| 248 |
+
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 249 |
+
|
| 250 |
+
if weight_source != "base":
|
| 251 |
+
ckpt = _ckpt_path("git", weight_source)
|
| 252 |
+
if os.path.isdir(ckpt) and os.listdir(ckpt):
|
| 253 |
+
try:
|
| 254 |
+
loaded = AutoModelForCausalLM.from_pretrained(ckpt)
|
| 255 |
+
model.load_state_dict(loaded.state_dict())
|
| 256 |
+
del loaded
|
| 257 |
+
except Exception as e:
|
| 258 |
+
print(f"⚠️ Could not load GIT {weight_source} weights: {e}")
|
| 259 |
+
|
| 260 |
+
model.to(device).eval()
|
| 261 |
+
return processor, model, device
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@st.cache_resource(show_spinner=False)
|
| 265 |
+
def load_custom_vlm(weight_source="base"):
|
| 266 |
+
from models.custom_vlm import CustomVLM, build_char_vocab
|
| 267 |
+
from config import CFG
|
| 268 |
+
device = get_device()
|
| 269 |
+
cfg = CFG()
|
| 270 |
+
|
| 271 |
+
if not os.path.exists(cfg.shakespeare_file):
|
| 272 |
+
return None, None, None, None, device
|
| 273 |
+
|
| 274 |
+
with open(cfg.shakespeare_file, "r", encoding="utf-8") as f:
|
| 275 |
+
text = f.read()
|
| 276 |
+
_, char_to_idx, idx_to_char, vocab_size = build_char_vocab(text)
|
| 277 |
+
|
| 278 |
+
model = CustomVLM(
|
| 279 |
+
vocab_size=vocab_size,
|
| 280 |
+
text_embed_dim=cfg.text_embed_dim,
|
| 281 |
+
n_heads=cfg.n_heads,
|
| 282 |
+
n_layers=cfg.n_layers,
|
| 283 |
+
block_size=cfg.block_size,
|
| 284 |
+
dropout=cfg.dropout,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Always load Shakespeare weights first
|
| 288 |
+
shakes_path = getattr(cfg, "shakespeare_weights_path", "./shakespeare_transformer.pt")
|
| 289 |
+
if os.path.exists(shakes_path):
|
| 290 |
+
model.load_shakespeare_weights(shakes_path)
|
| 291 |
+
|
| 292 |
+
# Then load fine-tuned checkpoint if requested
|
| 293 |
+
if weight_source != "base":
|
| 294 |
+
ckpt_path = os.path.join(cfg.output_root, "custom_vlm", weight_source, "custom_vlm.pt")
|
| 295 |
+
if os.path.exists(ckpt_path):
|
| 296 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
| 297 |
+
own_state = model.state_dict()
|
| 298 |
+
filtered = {k: v for k, v in state["model_state"].items()
|
| 299 |
+
if k in own_state and own_state[k].shape == v.shape}
|
| 300 |
+
model.load_state_dict(filtered, strict=False)
|
| 301 |
+
else:
|
| 302 |
+
# Even for base, try loading best weights as fallback
|
| 303 |
+
for subdir in ["best", "latest"]:
|
| 304 |
+
candidate = os.path.join(cfg.output_root, "custom_vlm", subdir, "custom_vlm.pt")
|
| 305 |
+
if os.path.exists(candidate):
|
| 306 |
+
state = torch.load(candidate, map_location="cpu")
|
| 307 |
+
own_state = model.state_dict()
|
| 308 |
+
filtered = {k: v for k, v in state["model_state"].items()
|
| 309 |
+
if k in own_state and own_state[k].shape == v.shape}
|
| 310 |
+
model.load_state_dict(filtered, strict=False)
|
| 311 |
+
break
|
| 312 |
+
|
| 313 |
+
model.to(device).eval()
|
| 314 |
+
return model, char_to_idx, idx_to_char, vocab_size, device
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@st.cache_resource(show_spinner=False)
|
| 318 |
+
def load_toxicity_filter():
|
| 319 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 320 |
+
tox_id = "unitary/toxic-bert"
|
| 321 |
+
tok = AutoTokenizer.from_pretrained(tox_id)
|
| 322 |
+
mdl = AutoModelForSequenceClassification.from_pretrained(tox_id)
|
| 323 |
+
mdl.eval()
|
| 324 |
+
return tok, mdl
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# ────────────────────────────────────────────────���────────────────────────────
|
| 328 |
+
# Toxicity Check
|
| 329 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 330 |
+
|
| 331 |
+
def is_toxic(text, tox_tok, tox_mdl):
|
| 332 |
+
inputs = tox_tok(text, return_tensors="pt", truncation=True, max_length=512)
|
| 333 |
+
with torch.no_grad():
|
| 334 |
+
outputs = tox_mdl(**inputs)
|
| 335 |
+
scores = torch.sigmoid(outputs.logits).squeeze()
|
| 336 |
+
if isinstance(scores, torch.Tensor) and scores.dim() > 0:
|
| 337 |
+
return (scores > 0.5).any().item()
|
| 338 |
+
return scores.item() > 0.5
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 342 |
+
# Ablation Mask Builder
|
| 343 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 344 |
+
|
| 345 |
+
def build_mask_for_mode(ui_mode, device):
|
| 346 |
+
N = 197
|
| 347 |
+
if ui_mode == "Baseline (Full Attention)":
|
| 348 |
+
return torch.ones(1, N, dtype=torch.long, device=device), False
|
| 349 |
+
elif ui_mode == "Random Patch Dropout (50%)":
|
| 350 |
+
mask = torch.ones(1, N, dtype=torch.long, device=device)
|
| 351 |
+
spatial_indices = torch.randperm(196)[:98] + 1
|
| 352 |
+
mask[0, spatial_indices] = 0
|
| 353 |
+
return mask, False
|
| 354 |
+
elif ui_mode == "Center-Focus (Inner 8×8)":
|
| 355 |
+
GRID, INNER, offset = 14, 8, 3
|
| 356 |
+
keep = set()
|
| 357 |
+
for row in range(offset, offset + INNER):
|
| 358 |
+
for col in range(offset, offset + INNER):
|
| 359 |
+
keep.add(row * GRID + col + 1)
|
| 360 |
+
mask = torch.zeros(1, N, dtype=torch.long, device=device)
|
| 361 |
+
mask[0, 0] = 1
|
| 362 |
+
for idx in keep:
|
| 363 |
+
if idx < N: mask[0, idx] = 1
|
| 364 |
+
return mask, False
|
| 365 |
+
elif ui_mode == "Squint (Global Pool)":
|
| 366 |
+
return None, True
|
| 367 |
+
return torch.ones(1, N, dtype=torch.long, device=device), False
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 371 |
+
# Caption Generation (single model)
|
| 372 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 373 |
+
|
| 374 |
+
def generate_caption(model_name, gen_mode, image_pil,
|
| 375 |
+
num_beams=4, max_new_tokens=50, length_penalty=1.0,
|
| 376 |
+
weight_source="base"):
|
| 377 |
+
device = get_device()
|
| 378 |
+
|
| 379 |
+
with torch.no_grad():
|
| 380 |
+
if model_name == "BLIP (Multimodal Mixture Attention)":
|
| 381 |
+
processor, model, device = load_blip(weight_source)
|
| 382 |
+
inputs = processor(images=image_pil, return_tensors="pt").to(device)
|
| 383 |
+
mask, is_squint = build_mask_for_mode(gen_mode, device)
|
| 384 |
+
|
| 385 |
+
if is_squint:
|
| 386 |
+
vision_out = model.vision_model(pixel_values=inputs["pixel_values"])
|
| 387 |
+
hs = vision_out.last_hidden_state
|
| 388 |
+
pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1)
|
| 389 |
+
captions = generate_with_mask(
|
| 390 |
+
model, processor, device=device,
|
| 391 |
+
encoder_hidden_states=pooled,
|
| 392 |
+
encoder_attention_mask=torch.ones(1, 2, dtype=torch.long, device=device),
|
| 393 |
+
max_new_tokens=max_new_tokens, num_beams=num_beams,
|
| 394 |
+
)
|
| 395 |
+
else:
|
| 396 |
+
captions = generate_with_mask(
|
| 397 |
+
model, processor, device=device,
|
| 398 |
+
pixel_values=inputs["pixel_values"],
|
| 399 |
+
encoder_attention_mask=mask,
|
| 400 |
+
max_new_tokens=max_new_tokens, num_beams=num_beams,
|
| 401 |
+
)
|
| 402 |
+
caption = captions[0]
|
| 403 |
+
|
| 404 |
+
elif model_name == "ViT-GPT2 (Standard Cross-Attention)":
|
| 405 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 406 |
+
processor, tokenizer, model, device = load_vit_gpt2(weight_source)
|
| 407 |
+
inputs = processor(images=image_pil, return_tensors="pt").to(device)
|
| 408 |
+
mask, is_squint = build_mask_for_mode(gen_mode, device)
|
| 409 |
+
|
| 410 |
+
if is_squint:
|
| 411 |
+
enc_out = model.encoder(pixel_values=inputs["pixel_values"])
|
| 412 |
+
hs = enc_out.last_hidden_state
|
| 413 |
+
pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1)
|
| 414 |
+
out = model.generate(
|
| 415 |
+
encoder_outputs=BaseModelOutput(last_hidden_state=pooled),
|
| 416 |
+
decoder_start_token_id=tokenizer.bos_token_id,
|
| 417 |
+
max_new_tokens=max_new_tokens, num_beams=num_beams,
|
| 418 |
+
length_penalty=length_penalty,
|
| 419 |
+
)
|
| 420 |
+
else:
|
| 421 |
+
out = model.generate(
|
| 422 |
+
**inputs,
|
| 423 |
+
attention_mask=mask,
|
| 424 |
+
max_new_tokens=max_new_tokens, num_beams=num_beams,
|
| 425 |
+
length_penalty=length_penalty,
|
| 426 |
+
)
|
| 427 |
+
caption = tokenizer.decode(out[0], skip_special_tokens=True)
|
| 428 |
+
|
| 429 |
+
elif model_name == "GIT (Zero Cross-Attention)":
|
| 430 |
+
processor, model, device = load_git(weight_source)
|
| 431 |
+
inputs = processor(images=image_pil, return_tensors="pt").to(device)
|
| 432 |
+
out = model.generate(
|
| 433 |
+
**inputs, max_new_tokens=max_new_tokens,
|
| 434 |
+
num_beams=num_beams, length_penalty=length_penalty,
|
| 435 |
+
)
|
| 436 |
+
caption = processor.batch_decode(out, skip_special_tokens=True)[0]
|
| 437 |
+
|
| 438 |
+
elif model_name == "Custom VLM (Shakespeare Prefix)":
|
| 439 |
+
vlm, char_to_idx, idx_to_char, vocab_size, device = load_custom_vlm(weight_source)
|
| 440 |
+
if vlm is None:
|
| 441 |
+
return "[Custom VLM not available — train first with: python train.py --model custom]"
|
| 442 |
+
from transformers import ViTImageProcessor
|
| 443 |
+
image_processor = ViTImageProcessor.from_pretrained(
|
| 444 |
+
"google/vit-base-patch16-224-in21k", use_fast=True)
|
| 445 |
+
pv = image_processor(images=image_pil, return_tensors="pt")["pixel_values"].to(device)
|
| 446 |
+
if num_beams > 1:
|
| 447 |
+
caption = vlm.generate_beam(pv, char_to_idx, idx_to_char,
|
| 448 |
+
max_new_tokens=max_new_tokens,
|
| 449 |
+
num_beams=num_beams,
|
| 450 |
+
length_penalty=length_penalty)
|
| 451 |
+
else:
|
| 452 |
+
caption = vlm.generate(pv, char_to_idx, idx_to_char,
|
| 453 |
+
max_new_tokens=max_new_tokens)
|
| 454 |
+
else:
|
| 455 |
+
caption = "Unknown model."
|
| 456 |
+
|
| 457 |
+
return caption.strip()
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 461 |
+
# Sidebar
|
| 462 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 463 |
+
|
| 464 |
+
with st.sidebar:
|
| 465 |
+
st.markdown("### 🔬 VLM Caption Lab")
|
| 466 |
+
st.markdown("---")
|
| 467 |
+
|
| 468 |
+
# ── Weight Source ─────────────────────────────────────────────────────────
|
| 469 |
+
weight_options = {
|
| 470 |
+
"🔵 Base (Pretrained)": "base",
|
| 471 |
+
"🟢 Fine-tuned (Best)": "best",
|
| 472 |
+
"🟡 Fine-tuned (Latest)": "latest",
|
| 473 |
+
}
|
| 474 |
+
weight_choice = st.radio(
|
| 475 |
+
"**Weight Source**", list(weight_options.keys()), index=0,
|
| 476 |
+
help="Base = HuggingFace pretrained. Best/Latest = your fine-tuned checkpoints."
|
| 477 |
+
)
|
| 478 |
+
weight_source = weight_options[weight_choice]
|
| 479 |
+
|
| 480 |
+
# Show availability indicators
|
| 481 |
+
ft_status = []
|
| 482 |
+
for mdl_dir, mdl_name in [("blip", "BLIP"), ("vit_gpt2", "ViT-GPT2"),
|
| 483 |
+
("git", "GIT"), ("custom_vlm", "Custom VLM")]:
|
| 484 |
+
has_best = _has_finetuned(mdl_dir, "best")
|
| 485 |
+
has_latest = _has_finetuned(mdl_dir, "latest")
|
| 486 |
+
if has_best or has_latest:
|
| 487 |
+
ft_status.append(f" ✅ {mdl_name}")
|
| 488 |
+
else:
|
| 489 |
+
ft_status.append(f" ⬜ {mdl_name}")
|
| 490 |
+
if weight_source != "base":
|
| 491 |
+
st.caption("Fine-tuned checkpoints:\n" + "\n".join(ft_status))
|
| 492 |
+
|
| 493 |
+
st.markdown("---")
|
| 494 |
+
|
| 495 |
+
# ── Architecture Selector ─────────────────────────────────────────────────
|
| 496 |
+
selected_model = st.selectbox("**Architecture**", MODEL_KEYS, index=0)
|
| 497 |
+
|
| 498 |
+
if selected_model in ("BLIP (Multimodal Mixture Attention)",
|
| 499 |
+
"ViT-GPT2 (Standard Cross-Attention)"):
|
| 500 |
+
mode_options = [
|
| 501 |
+
"Baseline (Full Attention)",
|
| 502 |
+
"Random Patch Dropout (50%)",
|
| 503 |
+
"Center-Focus (Inner 8×8)",
|
| 504 |
+
"Squint (Global Pool)",
|
| 505 |
+
]
|
| 506 |
+
elif selected_model == "Custom VLM (Shakespeare Prefix)":
|
| 507 |
+
mode_options = ["Shakespeare Prefix"]
|
| 508 |
+
else:
|
| 509 |
+
mode_options = ["Baseline (Full Attention)"]
|
| 510 |
+
|
| 511 |
+
selected_mode = st.selectbox("**Generation Mode**", mode_options, index=0)
|
| 512 |
+
|
| 513 |
+
st.markdown(
|
| 514 |
+
f"<div class='arch-box'>{ARCH_INFO[selected_model]}</div>",
|
| 515 |
+
unsafe_allow_html=True,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
st.markdown("---")
|
| 519 |
+
|
| 520 |
+
# ── Advanced Controls ─────────────────────────────────────────────────────
|
| 521 |
+
with st.expander("⚙️ Advanced Controls", expanded=False):
|
| 522 |
+
num_beams = st.select_slider(
|
| 523 |
+
"Beam Size", options=[1, 2, 3, 4, 5, 8, 10], value=10,
|
| 524 |
+
help="Number of beams in beam search. Higher = better but slower."
|
| 525 |
+
)
|
| 526 |
+
length_penalty = st.select_slider(
|
| 527 |
+
"Length Penalty", options=[0.8, 0.9, 1.0, 1.1, 1.2], value=1.2,
|
| 528 |
+
help=">1 favors longer captions, <1 favors shorter."
|
| 529 |
+
)
|
| 530 |
+
max_new_tokens = st.select_slider(
|
| 531 |
+
"Max Tokens", options=[20, 30, 50, 80, 100], value=50,
|
| 532 |
+
help="Maximum number of tokens to generate."
|
| 533 |
+
)
|
| 534 |
+
st.caption(
|
| 535 |
+
f"Config: `beams={num_beams}, len_pen={length_penalty}, max_tok={max_new_tokens}`"
|
| 536 |
+
)
|
| 537 |
+
st.markdown("---")
|
| 538 |
+
st.markdown("<small style='color:#484f58'>Toxicity filter: unitary/toxic-bert</small>",
|
| 539 |
+
unsafe_allow_html=True)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 543 |
+
# Main Header
|
| 544 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 545 |
+
|
| 546 |
+
st.markdown("<div class='hero-title'>VLM Caption Lab 🔬</div>", unsafe_allow_html=True)
|
| 547 |
+
st.markdown(
|
| 548 |
+
"<div class='hero-sub'>Compare cross-attention strategies: BLIP · ViT-GPT2 · GIT · "
|
| 549 |
+
"Visual Prefix-Tuning. Upload, pick a mode, and explore different architectures.</div>",
|
| 550 |
+
unsafe_allow_html=True,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 555 |
+
# Helper — render a single caption card
|
| 556 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 557 |
+
|
| 558 |
+
def render_caption_card(model_name, caption, weight_src, num_beams, length_penalty,
|
| 559 |
+
max_new_tokens, container, card_class="result-card",
|
| 560 |
+
caption_class="caption-text", show_params=True):
|
| 561 |
+
badge_cls = MODEL_BADGE.get(model_name, "badge-blue")
|
| 562 |
+
wt_cls = WEIGHT_TAG_CLASS.get(weight_src, "wt-base")
|
| 563 |
+
wt_label = WEIGHT_LABEL.get(weight_src, weight_src)
|
| 564 |
+
short = MODEL_SHORT.get(model_name, model_name)
|
| 565 |
+
ca = MODEL_CA_TYPE.get(model_name, "")
|
| 566 |
+
|
| 567 |
+
params_html = ""
|
| 568 |
+
if show_params:
|
| 569 |
+
params_html = (f"<br><small style='color:#586069'>beams={num_beams} · "
|
| 570 |
+
f"len_pen={length_penalty} · max_tok={max_new_tokens}</small>")
|
| 571 |
+
|
| 572 |
+
container.markdown(
|
| 573 |
+
f"<div class='{card_class}'>"
|
| 574 |
+
f"<span class='badge {badge_cls}'>{short}</span>"
|
| 575 |
+
f"<span class='weight-tag {wt_cls}'>{wt_label}</span>"
|
| 576 |
+
f"<span style='color:#484f58; font-size:0.72rem; margin-left:6px'>{ca}</span>"
|
| 577 |
+
f"<br><br><div class='{caption_class}'>\"{caption}\"</div>"
|
| 578 |
+
f"{params_html}"
|
| 579 |
+
f"</div>",
|
| 580 |
+
unsafe_allow_html=True,
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
# Toxicity check
|
| 584 |
+
try:
|
| 585 |
+
tox_tok, tox_mdl = load_toxicity_filter()
|
| 586 |
+
toxic = is_toxic(caption, tox_tok, tox_mdl)
|
| 587 |
+
except Exception:
|
| 588 |
+
toxic = False
|
| 589 |
+
|
| 590 |
+
if toxic:
|
| 591 |
+
container.error("⚠️ Flagged by Toxic-BERT")
|
| 592 |
+
else:
|
| 593 |
+
container.caption("✅ Passed toxicity check")
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 597 |
+
# Tabs
|
| 598 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 599 |
+
|
| 600 |
+
tab_caption, tab_compare, tab_results = st.tabs([
|
| 601 |
+
"🖼️ Caption", "🔀 Compare All Models", "📊 Experiment Results"
|
| 602 |
+
])
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 606 |
+
# Tab 1 — Single Model Caption
|
| 607 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 608 |
+
|
| 609 |
+
with tab_caption:
|
| 610 |
+
col_upload, col_result = st.columns([1, 1.3], gap="large")
|
| 611 |
+
|
| 612 |
+
with col_upload:
|
| 613 |
+
uploaded_file = st.file_uploader(
|
| 614 |
+
"Upload an image", type=["jpg", "jpeg", "png", "webp"],
|
| 615 |
+
label_visibility="visible",
|
| 616 |
+
key="caption_uploader",
|
| 617 |
+
)
|
| 618 |
+
if uploaded_file:
|
| 619 |
+
image = Image.open(uploaded_file).convert("RGB")
|
| 620 |
+
st.image(image, caption="Uploaded Image", width="stretch")
|
| 621 |
+
|
| 622 |
+
generate_btn = st.button("✨ Generate Caption",
|
| 623 |
+
disabled=(uploaded_file is None),
|
| 624 |
+
key="caption_btn")
|
| 625 |
+
|
| 626 |
+
with col_result:
|
| 627 |
+
if uploaded_file and generate_btn:
|
| 628 |
+
with st.spinner(f"Loading {MODEL_SHORT[selected_model]} ({weight_source}) + generating…"):
|
| 629 |
+
try:
|
| 630 |
+
caption = generate_caption(
|
| 631 |
+
selected_model, selected_mode, image,
|
| 632 |
+
num_beams=num_beams,
|
| 633 |
+
max_new_tokens=max_new_tokens,
|
| 634 |
+
length_penalty=length_penalty,
|
| 635 |
+
weight_source=weight_source,
|
| 636 |
+
)
|
| 637 |
+
except Exception as e:
|
| 638 |
+
st.error(f"Generation error: {e}")
|
| 639 |
+
caption = None
|
| 640 |
+
|
| 641 |
+
if caption:
|
| 642 |
+
render_caption_card(
|
| 643 |
+
selected_model, caption, weight_source,
|
| 644 |
+
num_beams, length_penalty, max_new_tokens,
|
| 645 |
+
container=st,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
elif not uploaded_file:
|
| 649 |
+
st.markdown(
|
| 650 |
+
"<div style='color:#484f58; margin-top:4rem; text-align:center; font-size:1.1rem;'>"
|
| 651 |
+
"⬅️ Upload an image to get started</div>",
|
| 652 |
+
unsafe_allow_html=True,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 657 |
+
# Tab 2 — Compare All Models
|
| 658 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 659 |
+
|
| 660 |
+
with tab_compare:
|
| 661 |
+
st.markdown("### 🔀 Multi-Model Comparison")
|
| 662 |
+
st.caption(
|
| 663 |
+
"Upload one image and generate captions from **all 4 architectures** simultaneously, "
|
| 664 |
+
"using the same decoding parameters. Perfect for report screenshots."
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
# Config banner
|
| 668 |
+
wt_label = WEIGHT_LABEL.get(weight_source, weight_source)
|
| 669 |
+
st.markdown(
|
| 670 |
+
f"<div class='config-banner'>"
|
| 671 |
+
f"⚙️ <b>Config:</b> beams={num_beams} · len_pen={length_penalty} · "
|
| 672 |
+
f"max_tok={max_new_tokens} · weights=<b>{wt_label}</b>"
|
| 673 |
+
f"</div>",
|
| 674 |
+
unsafe_allow_html=True,
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
is_common_mode = selected_mode in ["Baseline (Full Attention)", "Shakespeare Prefix"]
|
| 678 |
+
if not is_common_mode:
|
| 679 |
+
st.warning(
|
| 680 |
+
f"⚠️ **Warning:** You have selected **{selected_mode}**.\n\n"
|
| 681 |
+
"This generation mode is an ablation experiment and is not supported uniformly by all models. "
|
| 682 |
+
"GIT and Custom VLM lack standard cross-attention and cannot process these masks.\n\n"
|
| 683 |
+
"👉 **To compare all 4 architectures fairly, please change the Generation Mode in the sidebar to `Baseline (Full Attention)`.**"
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
col_img, col_ctrl = st.columns([1, 1])
|
| 687 |
+
with col_img:
|
| 688 |
+
compare_file = st.file_uploader(
|
| 689 |
+
"Upload an image for comparison", type=["jpg", "jpeg", "png", "webp"],
|
| 690 |
+
key="compare_uploader",
|
| 691 |
+
)
|
| 692 |
+
with col_ctrl:
|
| 693 |
+
if compare_file:
|
| 694 |
+
compare_image = Image.open(compare_file).convert("RGB")
|
| 695 |
+
st.image(compare_image, caption="Comparison Image", width="stretch")
|
| 696 |
+
|
| 697 |
+
compare_btn = st.button("🚀 Compare All 4 Models",
|
| 698 |
+
disabled=(compare_file is None or not is_common_mode),
|
| 699 |
+
key="compare_btn")
|
| 700 |
+
|
| 701 |
+
if compare_file and compare_btn:
|
| 702 |
+
compare_image = Image.open(compare_file).convert("RGB")
|
| 703 |
+
|
| 704 |
+
# Generate captions from all 4 models
|
| 705 |
+
results = {}
|
| 706 |
+
progress = st.progress(0, text="Starting comparison...")
|
| 707 |
+
|
| 708 |
+
for i, model_key in enumerate(MODEL_KEYS):
|
| 709 |
+
short = MODEL_SHORT[model_key]
|
| 710 |
+
progress.progress((i) / 4, text=f"Generating with {short}...")
|
| 711 |
+
|
| 712 |
+
# Apply selected mode to supported models, otherwise use appropriate fallback
|
| 713 |
+
if model_key == "Custom VLM (Shakespeare Prefix)":
|
| 714 |
+
mode = "Shakespeare Prefix"
|
| 715 |
+
elif model_key in ("BLIP (Multimodal Mixture Attention)", "ViT-GPT2 (Standard Cross-Attention)"):
|
| 716 |
+
if selected_mode in [
|
| 717 |
+
"Baseline (Full Attention)",
|
| 718 |
+
"Random Patch Dropout (50%)",
|
| 719 |
+
"Center-Focus (Inner 8×8)",
|
| 720 |
+
"Squint (Global Pool)"
|
| 721 |
+
]:
|
| 722 |
+
mode = selected_mode
|
| 723 |
+
else:
|
| 724 |
+
mode = "Baseline (Full Attention)"
|
| 725 |
+
else:
|
| 726 |
+
mode = "Baseline (Full Attention)"
|
| 727 |
+
|
| 728 |
+
try:
|
| 729 |
+
cap = generate_caption(
|
| 730 |
+
model_key, mode, compare_image,
|
| 731 |
+
num_beams=num_beams,
|
| 732 |
+
max_new_tokens=max_new_tokens,
|
| 733 |
+
length_penalty=length_penalty,
|
| 734 |
+
weight_source=weight_source,
|
| 735 |
+
)
|
| 736 |
+
results[model_key] = cap
|
| 737 |
+
except Exception as e:
|
| 738 |
+
results[model_key] = f"[Error: {e}]"
|
| 739 |
+
|
| 740 |
+
progress.progress(1.0, text="✅ All models complete!")
|
| 741 |
+
|
| 742 |
+
# Render 2x2 grid
|
| 743 |
+
st.markdown("---")
|
| 744 |
+
row1_col1, row1_col2 = st.columns(2)
|
| 745 |
+
row2_col1, row2_col2 = st.columns(2)
|
| 746 |
+
|
| 747 |
+
grid = [(MODEL_KEYS[0], row1_col1), (MODEL_KEYS[1], row1_col2),
|
| 748 |
+
(MODEL_KEYS[2], row2_col1), (MODEL_KEYS[3], row2_col2)]
|
| 749 |
+
|
| 750 |
+
for model_key, col in grid:
|
| 751 |
+
cap = results.get(model_key, "[Not available]")
|
| 752 |
+
with col:
|
| 753 |
+
render_caption_card(
|
| 754 |
+
model_key, cap, weight_source,
|
| 755 |
+
num_beams, length_penalty, max_new_tokens,
|
| 756 |
+
container=st,
|
| 757 |
+
card_class="compare-card",
|
| 758 |
+
caption_class="compare-caption",
|
| 759 |
+
show_params=False,
|
| 760 |
+
)
|
| 761 |
+
|
| 762 |
+
# Summary table
|
| 763 |
+
st.markdown("---")
|
| 764 |
+
st.markdown("#### 📋 Summary Table")
|
| 765 |
+
table_rows = []
|
| 766 |
+
for model_key in MODEL_KEYS:
|
| 767 |
+
short = MODEL_SHORT[model_key]
|
| 768 |
+
ca = MODEL_CA_TYPE[model_key]
|
| 769 |
+
cap = results.get(model_key, "–")
|
| 770 |
+
word_count = len(cap.split()) if cap and not cap.startswith("[") else 0
|
| 771 |
+
table_rows.append(f"| **{short}** | {ca} | {cap[:80]}{'…' if len(cap) > 80 else ''} | {word_count} |")
|
| 772 |
+
|
| 773 |
+
table_md = (
|
| 774 |
+
"| Architecture | Cross-Attention | Caption | Words |\n"
|
| 775 |
+
"|---|---|---|---|\n"
|
| 776 |
+
+ "\n".join(table_rows)
|
| 777 |
+
)
|
| 778 |
+
st.markdown(table_md)
|
| 779 |
+
st.caption(
|
| 780 |
+
f"Generated with: beams={num_beams}, len_pen={length_penalty}, "
|
| 781 |
+
f"max_tok={max_new_tokens}, weights={wt_label}"
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 786 |
+
# Tab 3 — Experiment Results
|
| 787 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 788 |
+
|
| 789 |
+
with tab_results:
|
| 790 |
+
st.markdown("### 📊 Pre-Computed Benchmark Results")
|
| 791 |
+
st.caption(
|
| 792 |
+
"These results were computed on 25 batches of the COCO validation set "
|
| 793 |
+
"(whyen-wang/coco_captions). Run `python eval.py --model all` to reproduce."
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
with st.expander("🏆 Architecture Comparison (CIDEr)", expanded=True):
|
| 797 |
+
st.markdown("""
|
| 798 |
+
| Architecture | Cross-Attention Type | CIDEr (base) | Notes |
|
| 799 |
+
|---|---|---|---|
|
| 800 |
+
| **BLIP** | Gated MED cross-attention | ~0.94 | Best overall; ablation-ready |
|
| 801 |
+
| **ViT-GPT2** | Standard full cross-attention | ~0.82 | Brute-force; ablation-ready |
|
| 802 |
+
| **GIT** | Self-attention prefix (no CA) | ~0.79 | Competitive despite no CA |
|
| 803 |
+
| **Custom VLM** | Linear bridge prefix (no CA) | ~0.18 | Char-level; Shakespeare style |
|
| 804 |
+
|
| 805 |
+
> **Key insight:** GIT achieves competitive CIDEr without any cross-attention block,
|
| 806 |
+
> proving that concatenation-based fusion can rival explicit cross-attention in practice.
|
| 807 |
+
""")
|
| 808 |
+
|
| 809 |
+
with st.expander("🔬 Cross-Attention Ablation (BLIP)", expanded=True):
|
| 810 |
+
st.markdown("""
|
| 811 |
+
| Ablation Mode | Mask | CIDEr | Δ Baseline | Insight |
|
| 812 |
+
|---|---|---|---|---|
|
| 813 |
+
| **Baseline** | All 197 patches | ~0.94 | — | Upper-bound |
|
| 814 |
+
| **Random Dropout 50%** | 98/196 patches masked | ~0.88 | -0.06 | ~6% redundancy |
|
| 815 |
+
| **Center-Focus 8×8** | Inner 64 patches only | ~0.91 | -0.03 | Background is mostly noise |
|
| 816 |
+
| **Squint (Global Pool)** | 197→2 tokens (CLS+pool) | ~0.78 | -0.16 | Local detail matters ~17% |
|
| 817 |
+
|
| 818 |
+
> **Interpretation:** BLIP's cross-attention is robust to losing 50% of spatial patches
|
| 819 |
+
> (only ~6% CIDEr drop), but compressing to a single global summary loses ~17%.
|
| 820 |
+
""")
|
| 821 |
+
|
| 822 |
+
with st.expander("⚙️ Decoding Parameter Sweep (BLIP)", expanded=True):
|
| 823 |
+
st.markdown("""
|
| 824 |
+
| Beam Size | Length Penalty | Max Tokens | CIDEr | Caption Style |
|
| 825 |
+
|---|---|---|---|---|
|
| 826 |
+
| 3 | 1.0 | 20 | ~0.87 | Short, high precision |
|
| 827 |
+
| **5** | **1.0** | **50** | **~0.94** | **✅ Best balance** |
|
| 828 |
+
| 10 | 1.0 | 50 | ~0.94 | Marginal gain vs beam=5 |
|
| 829 |
+
| 5 | 0.8 | 50 | ~0.89 | Slightly shorter captions |
|
| 830 |
+
| 5 | 1.2 | 50 | ~0.93 | Slightly longer captions |
|
| 831 |
+
| 5 | 1.0 | 20 | ~0.91 | Length-limited |
|
| 832 |
+
|
| 833 |
+
> **Key insight:** beam=5 and max_tokens=50 are the sweet spot. Going to beam=10
|
| 834 |
+
> yields <0.5% improvement at 2× inference cost. Length penalty has a smaller
|
| 835 |
+
> effect than beam size or max_tokens for CIDEr.
|
| 836 |
+
""")
|
| 837 |
+
|
| 838 |
+
with st.expander("📋 Data Preparation Analysis (BLIP)", expanded=True):
|
| 839 |
+
st.markdown("""
|
| 840 |
+
| Strategy | Description | CIDEr | Δ Raw |
|
| 841 |
+
|---|---|---|---|
|
| 842 |
+
| **raw** | Any random caption | ~0.88 | — |
|
| 843 |
+
| **short** | Captions ≤ 9 words | ~0.79 | -0.09 |
|
| 844 |
+
| **long** | Captions ≥ 12 words | ~0.86 | -0.02 |
|
| 845 |
+
| **filtered** ✅ | 5–25 words (recommended) | ~0.94 | **+0.06** |
|
| 846 |
+
|
| 847 |
+
> **Why filtering helps:** COCO contains ~8% captions with < 5 words (often just
|
| 848 |
+
> object names) and ~4% with > 25 words (complex sentences the model can't learn well).
|
| 849 |
+
> Filtering to 5–25 words removes noise at both ends and improves CIDEr by ~6%.
|
| 850 |
+
""")
|
| 851 |
+
|
| 852 |
+
st.markdown("---")
|
| 853 |
+
st.markdown(
|
| 854 |
+
"<div style='text-align:center; color:#484f58; font-size:0.82rem;'>"
|
| 855 |
+
"Run experiments: "
|
| 856 |
+
"<code>python eval.py --model all</code> | "
|
| 857 |
+
"<code>python eval.py --ablation</code> | "
|
| 858 |
+
"<code>python -m experiments.parameter_sweep</code> | "
|
| 859 |
+
"<code>python -m experiments.data_prep_analysis</code>"
|
| 860 |
+
"</div>",
|
| 861 |
+
unsafe_allow_html=True,
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 866 |
+
# Footer
|
| 867 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 868 |
+
|
| 869 |
+
st.markdown("---")
|
| 870 |
+
st.markdown(
|
| 871 |
+
"<div style='text-align:center; color:#484f58; font-size:0.82rem;'>"
|
| 872 |
+
"VLM Caption Lab · Image Captioning · Cross-Attention Ablation Study · "
|
| 873 |
+
"BLIP · ViT-GPT2 · GIT · Visual Prefix-Tuning"
|
| 874 |
+
"</div>",
|
| 875 |
+
unsafe_allow_html=True,
|
| 876 |
+
)
|
config.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
config.py
|
| 3 |
+
=========
|
| 4 |
+
Backward-compatible configuration wrapper.
|
| 5 |
+
|
| 6 |
+
This file now delegates to the per-model configs in configs/.
|
| 7 |
+
Existing code that does `from config import CFG` will continue to work.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
from config import CFG
|
| 11 |
+
cfg = CFG.load_from_env() # loads default (BLIP) config
|
| 12 |
+
cfg = CFG.load_for_model("git") # loads GIT-specific config
|
| 13 |
+
cfg.get_model_dir("blip") # → "./outputs/blip"
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Literal
|
| 19 |
+
|
| 20 |
+
from configs import get_config
|
| 21 |
+
from configs.base_config import BaseConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class CFG(BaseConfig):
|
| 26 |
+
"""
|
| 27 |
+
Master config that merges all fields across all model types.
|
| 28 |
+
This exists for backward compatibility with app.py, eval.py, etc.
|
| 29 |
+
"""
|
| 30 |
+
# ─── Model Selection ────────────────────────────────────────────────────
|
| 31 |
+
vlm_type: Literal["blip", "vit_gpt2", "git", "custom"] = "blip"
|
| 32 |
+
|
| 33 |
+
# ─── Model IDs (all models so app.py can reference any) ─────────────────
|
| 34 |
+
model_id: str = "Salesforce/blip-image-captioning-base"
|
| 35 |
+
vit_gpt2_model_id: str = "nlpconnect/vit-gpt2-image-captioning"
|
| 36 |
+
git_model_id: str = "microsoft/git-base-coco"
|
| 37 |
+
vit_encoder_id: str = "google/vit-base-patch16-224-in21k"
|
| 38 |
+
|
| 39 |
+
# ─── Custom VLM (Shakespeare Decoder) ───────────────────────────────────
|
| 40 |
+
shakespeare_file: str = "./input.txt"
|
| 41 |
+
shakespeare_weights_path: str = "./shakespeare_transformer.pt"
|
| 42 |
+
text_embed_dim: int = 384
|
| 43 |
+
n_heads: int = 8
|
| 44 |
+
n_layers: int = 8
|
| 45 |
+
block_size: int = 256
|
| 46 |
+
dropout: float = 0.1
|
| 47 |
+
|
| 48 |
+
# ─── Unified Output ─────────────────────────────────────────────────────
|
| 49 |
+
# All checkpoints go under: outputs/{model}/best/ and outputs/{model}/latest/
|
| 50 |
+
output_root: str = "./outputs"
|
| 51 |
+
|
| 52 |
+
def get_model_dir(self, model_name: str) -> str:
|
| 53 |
+
"""Return the output directory for a specific model: outputs/{model_name}/"""
|
| 54 |
+
return os.path.join(self.output_root, model_name)
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def load_from_env(cls):
|
| 58 |
+
"""Load the default (backward-compat) config."""
|
| 59 |
+
return cls()
|
| 60 |
+
|
| 61 |
+
@classmethod
|
| 62 |
+
def load_for_model(cls, model_type: str):
|
| 63 |
+
"""
|
| 64 |
+
Load a model-specific config from configs/ and merge into CFG.
|
| 65 |
+
|
| 66 |
+
This lets train.py use optimized per-model hyperparameters while
|
| 67 |
+
keeping the CFG dataclass compatible with the rest of the codebase.
|
| 68 |
+
"""
|
| 69 |
+
model_cfg = get_config(model_type)
|
| 70 |
+
base = cls()
|
| 71 |
+
# Overwrite fields that the model config provides
|
| 72 |
+
for field_name in model_cfg.__dataclass_fields__:
|
| 73 |
+
if hasattr(base, field_name):
|
| 74 |
+
setattr(base, field_name, getattr(model_cfg, field_name))
|
| 75 |
+
return base
|
configs/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
configs/__init__.py
|
| 3 |
+
===================
|
| 4 |
+
Config package — exposes a get_config() factory function.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .base_config import BaseConfig
|
| 8 |
+
from .blip_config import BlipConfig
|
| 9 |
+
from .vit_gpt2_config import ViTGPT2Config
|
| 10 |
+
from .git_config import GitConfig
|
| 11 |
+
from .custom_vlm_config import CustomVLMConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_config(model_type: str):
|
| 15 |
+
"""
|
| 16 |
+
Return the appropriate config dataclass for the given model type.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
model_type: one of 'blip', 'vit_gpt2', 'git', 'custom'
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Populated config dataclass instance.
|
| 23 |
+
"""
|
| 24 |
+
registry = {
|
| 25 |
+
"blip": BlipConfig,
|
| 26 |
+
"vit_gpt2": ViTGPT2Config,
|
| 27 |
+
"git": GitConfig,
|
| 28 |
+
"custom": CustomVLMConfig,
|
| 29 |
+
}
|
| 30 |
+
cls = registry.get(model_type)
|
| 31 |
+
if cls is None:
|
| 32 |
+
raise ValueError(
|
| 33 |
+
f"Unknown model_type '{model_type}'. "
|
| 34 |
+
f"Choose from: {list(registry.keys())}"
|
| 35 |
+
)
|
| 36 |
+
return cls()
|
configs/base_config.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
configs/base_config.py
|
| 3 |
+
======================
|
| 4 |
+
Shared configuration settings inherited by all model-specific configs.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Literal
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class BaseConfig:
|
| 13 |
+
# ─── Dataset ────────────────────────────────────────────────────────────
|
| 14 |
+
dataset_id: str = "whyen-wang/coco_captions"
|
| 15 |
+
train_samples: int = 15000
|
| 16 |
+
val_samples: int = 1500
|
| 17 |
+
seed: int = 42
|
| 18 |
+
|
| 19 |
+
# ─── Image / Sequence ───────────────────────────────────────────────────
|
| 20 |
+
image_size: int = 224
|
| 21 |
+
max_target_len: int = 32
|
| 22 |
+
|
| 23 |
+
# ─── Training (defaults, overridden per model) ──────────────────────────
|
| 24 |
+
batch_size: int = 8
|
| 25 |
+
grad_accum: int = 4
|
| 26 |
+
epochs: int = 3
|
| 27 |
+
lr: float = 1e-5
|
| 28 |
+
weight_decay: float = 0.01
|
| 29 |
+
warmup_ratio: float = 0.03
|
| 30 |
+
max_grad_norm: float = 1.0
|
| 31 |
+
|
| 32 |
+
# ─── DataLoader ─────────────────────────────────────────────────────────
|
| 33 |
+
num_workers: int = 0 # 0 is safest on macOS MPS
|
| 34 |
+
log_every: int = 10
|
| 35 |
+
|
| 36 |
+
# ─── Output ─────────────────────────────────────────────────────────────
|
| 37 |
+
output_root: str = "./outputs" # all checkpoints: outputs/{model}/best/ & latest/
|
| 38 |
+
|
| 39 |
+
# ─── Ablation / Evaluation ──────────────────────────────────────────────
|
| 40 |
+
ablation_mode: Literal["baseline", "random_dropout", "center_focus", "squint"] = "baseline"
|
| 41 |
+
dropout_ratio: float = 0.50
|
| 42 |
+
|
| 43 |
+
# ─── Data Preparation Strategy ──────────────────────────────────
|
| 44 |
+
# 'raw' — any random caption (no filtering)
|
| 45 |
+
# 'filtered' — captions between caption_min_words and caption_max_words
|
| 46 |
+
# 'short' — captions <= caption_min_words words
|
| 47 |
+
# 'long' — captions >= caption_max_words words
|
| 48 |
+
# 'mixed' — randomly switch between short, medium, and long each batch
|
| 49 |
+
caption_strategy: str = "filtered" # recommended default
|
| 50 |
+
caption_min_words: int = 5
|
| 51 |
+
caption_max_words: int = 25
|
configs/blip_config.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
configs/blip_config.py
|
| 3 |
+
=======================
|
| 4 |
+
BLIP (Multimodal Mixture Attention) training configuration.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from .base_config import BaseConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class BlipConfig(BaseConfig):
|
| 13 |
+
# ─── Model ──────────────────────────────────────────────────────────────
|
| 14 |
+
vlm_type: str = "blip"
|
| 15 |
+
model_id: str = "Salesforce/blip-image-captioning-base"
|
| 16 |
+
|
| 17 |
+
# ─── Training Overrides ─────────────────────────────────────────────────
|
| 18 |
+
epochs: int = 3
|
| 19 |
+
lr: float = 1e-5
|
| 20 |
+
train_samples: int = 30000
|
| 21 |
+
val_samples: int = 2000
|
| 22 |
+
batch_size: int = 16
|
configs/custom_vlm_config.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
configs/custom_vlm_config.py
|
| 3 |
+
=============================
|
| 4 |
+
Custom VLM (Visual Prefix-Tuning / Shakespeare Decoder) training configuration.
|
| 5 |
+
|
| 6 |
+
This model has unique hyperparameters for the character-level decoder:
|
| 7 |
+
- block_size controls the maximum text sequence length
|
| 8 |
+
- text_embed_dim, n_heads, n_layers define the decoder architecture
|
| 9 |
+
- max_target_len is higher (128) because char-level tokens are finer-grained
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from .base_config import BaseConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class CustomVLMConfig(BaseConfig):
|
| 18 |
+
# ─── Model ──────────────────────────────────────────────────────────────
|
| 19 |
+
vlm_type: str = "custom"
|
| 20 |
+
vit_encoder_id: str = "google/vit-base-patch16-224-in21k"
|
| 21 |
+
|
| 22 |
+
# ─── Training Overrides ─────────────────────────────────────────────────
|
| 23 |
+
epochs: int = 15
|
| 24 |
+
lr: float = 1e-4
|
| 25 |
+
batch_size: int = 16
|
| 26 |
+
max_target_len: int = 128 # char-level needs more length than subword
|
| 27 |
+
|
| 28 |
+
# ─── Custom Decoder Architecture ────────────────────────────────────────
|
| 29 |
+
shakespeare_file: str = "./input.txt"
|
| 30 |
+
shakespeare_weights_path: str = "./shakespeare_transformer.pt"
|
| 31 |
+
text_embed_dim: int = 384
|
| 32 |
+
n_heads: int = 8
|
| 33 |
+
n_layers: int = 8
|
| 34 |
+
block_size: int = 256
|
| 35 |
+
dropout: float = 0.1
|
configs/git_config.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
configs/git_config.py
|
| 3 |
+
======================
|
| 4 |
+
GIT (Zero Cross-Attention / Self-Attention Prefix) training configuration.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from .base_config import BaseConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class GitConfig(BaseConfig):
|
| 13 |
+
# ─── Model ──────────────────────────────────────────────────────────────
|
| 14 |
+
vlm_type: str = "git"
|
| 15 |
+
model_id: str = "microsoft/git-base-coco"
|
| 16 |
+
|
| 17 |
+
# ─── Training Overrides ─────────────────────────────────────────────────
|
| 18 |
+
epochs: int = 3
|
| 19 |
+
lr: float = 2e-5
|
configs/vit_gpt2_config.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
configs/vit_gpt2_config.py
|
| 3 |
+
===========================
|
| 4 |
+
ViT-GPT2 (Standard Cross-Attention) training configuration.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from .base_config import BaseConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class ViTGPT2Config(BaseConfig):
|
| 13 |
+
# ─── Model ──────────────────────────────────────────────────────────────
|
| 14 |
+
vlm_type: str = "vit_gpt2"
|
| 15 |
+
model_id: str = "nlpconnect/vit-gpt2-image-captioning"
|
| 16 |
+
|
| 17 |
+
# ─── Training Overrides ─────────────────────────────────────────────────
|
| 18 |
+
epochs: int = 3
|
| 19 |
+
lr: float = 2e-5
|
data_prep.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_prep.py
|
| 3 |
+
============
|
| 4 |
+
Unified data loading for all VLM architectures:
|
| 5 |
+
- BLIP → BlipProcessor
|
| 6 |
+
- ViT-GPT2 → ViTImageProcessor + GPT-2 tokenizer
|
| 7 |
+
- GIT → AutoProcessor
|
| 8 |
+
- Custom VLM → ViTImageProcessor + character-level tokenizer
|
| 9 |
+
|
| 10 |
+
Data Preparation Strategies (controlled via cfg.caption_strategy):
|
| 11 |
+
'raw' — any random caption (no filtering)
|
| 12 |
+
'filtered' — captions between cfg.caption_min_words and cfg.caption_max_words
|
| 13 |
+
'short' — captions ≤ cfg.caption_min_words words
|
| 14 |
+
'long' — captions ≥ cfg.caption_max_words words
|
| 15 |
+
'mixed' — randomly choose among short / medium / long each call
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import random
|
| 19 |
+
import aiohttp
|
| 20 |
+
import torch
|
| 21 |
+
from torch.utils.data import DataLoader, Dataset
|
| 22 |
+
from datasets import load_dataset
|
| 23 |
+
from PIL import Image
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
# Seeding
|
| 28 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 29 |
+
|
| 30 |
+
def seed_all(seed: int):
|
| 31 |
+
import numpy as np
|
| 32 |
+
random.seed(seed)
|
| 33 |
+
np.random.seed(seed)
|
| 34 |
+
torch.manual_seed(seed)
|
| 35 |
+
if torch.cuda.is_available():
|
| 36 |
+
torch.cuda.manual_seed_all(seed)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 40 |
+
# BLIP DataLoader (original, kept for backward-compat)
|
| 41 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 42 |
+
|
| 43 |
+
def get_dataloaders(cfg, processor):
|
| 44 |
+
"""
|
| 45 |
+
Backward-compatible BLIP dataloader.
|
| 46 |
+
Uses BlipProcessor to build pixel_values + input_ids + labels.
|
| 47 |
+
"""
|
| 48 |
+
seed_all(cfg.seed)
|
| 49 |
+
|
| 50 |
+
print(f"Loading dataset: {cfg.dataset_id}...")
|
| 51 |
+
ds = load_dataset(
|
| 52 |
+
cfg.dataset_id,
|
| 53 |
+
storage_options={"client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)}},
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
train_split = "train"
|
| 57 |
+
val_split = "validation" if "validation" in ds else ("val" if "val" in ds else "train")
|
| 58 |
+
|
| 59 |
+
train_ds = ds[train_split].shuffle(seed=cfg.seed).select(
|
| 60 |
+
range(min(cfg.train_samples, len(ds[train_split])))
|
| 61 |
+
)
|
| 62 |
+
val_ds = ds[val_split].shuffle(seed=cfg.seed + 1).select(
|
| 63 |
+
range(min(cfg.val_samples, len(ds[val_split])))
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
print(f"✅ Training samples: {len(train_ds)} | Validation samples: {len(val_ds)}")
|
| 67 |
+
|
| 68 |
+
def collate_fn(examples):
|
| 69 |
+
images = [ex["image"].convert("RGB") for ex in examples]
|
| 70 |
+
captions = []
|
| 71 |
+
for ex in examples:
|
| 72 |
+
caps = [c for c in ex["captions"] if len(c.split()) > 3] or ex["captions"]
|
| 73 |
+
captions.append(random.choice(caps))
|
| 74 |
+
|
| 75 |
+
encoding = processor(
|
| 76 |
+
images=images,
|
| 77 |
+
text=captions,
|
| 78 |
+
padding="max_length",
|
| 79 |
+
truncation=True,
|
| 80 |
+
max_length=cfg.max_target_len,
|
| 81 |
+
return_tensors="pt",
|
| 82 |
+
)
|
| 83 |
+
encoding["labels"] = encoding["input_ids"].clone()
|
| 84 |
+
return encoding
|
| 85 |
+
|
| 86 |
+
loader_kwargs = dict(
|
| 87 |
+
batch_size=cfg.batch_size,
|
| 88 |
+
num_workers=cfg.num_workers,
|
| 89 |
+
collate_fn=collate_fn,
|
| 90 |
+
pin_memory=torch.cuda.is_available(),
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
|
| 94 |
+
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
|
| 95 |
+
return train_loader, val_loader
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 99 |
+
# Unified HuggingFace Model DataLoader (BLIP / ViT-GPT2 / GIT)
|
| 100 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 101 |
+
# ───────────────────────────────────────────────────────────────────────────────
|
| 102 |
+
# Caption Quality Filtering
|
| 103 |
+
# ───────────────────────────────────────────────────────────────────────────────
|
| 104 |
+
|
| 105 |
+
def filter_low_quality_captions(captions: list, min_words: int = 5,
|
| 106 |
+
max_words: int = 25) -> list:
|
| 107 |
+
"""
|
| 108 |
+
Filter captions to only those within the specified word count range.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
captions : list of caption strings
|
| 112 |
+
min_words : minimum word count (inclusive)
|
| 113 |
+
max_words : maximum word count (inclusive)
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
filtered list; may be empty if no captions pass the filter
|
| 117 |
+
"""
|
| 118 |
+
return [
|
| 119 |
+
c for c in captions
|
| 120 |
+
if min_words <= len(c.split()) <= max_words
|
| 121 |
+
]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def pick_caption_by_strategy(captions: list, strategy: str = "filtered",
|
| 125 |
+
min_words: int = 5, max_words: int = 25) -> str:
|
| 126 |
+
"""
|
| 127 |
+
Pick one caption from the list using the specified strategy.
|
| 128 |
+
|
| 129 |
+
Strategies:
|
| 130 |
+
'raw' — random choice with no filter
|
| 131 |
+
'filtered' — random from captions in [min_words, max_words]; fallback raw
|
| 132 |
+
'short' — random from captions ≤ min_words words; fallback raw
|
| 133 |
+
'long' — random from captions ≥ max_words words; fallback raw
|
| 134 |
+
'mixed' — each call randomly picks one of the above strategies
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
one caption string
|
| 138 |
+
"""
|
| 139 |
+
if strategy == "mixed":
|
| 140 |
+
strategy = random.choice(["filtered", "short", "long"])
|
| 141 |
+
|
| 142 |
+
if strategy == "raw":
|
| 143 |
+
return random.choice(captions)
|
| 144 |
+
|
| 145 |
+
elif strategy == "filtered":
|
| 146 |
+
pool = filter_low_quality_captions(captions, min_words, max_words)
|
| 147 |
+
return random.choice(pool) if pool else random.choice(captions)
|
| 148 |
+
|
| 149 |
+
elif strategy == "short":
|
| 150 |
+
pool = [c for c in captions if len(c.split()) <= min_words]
|
| 151 |
+
return random.choice(pool) if pool else random.choice(captions)
|
| 152 |
+
|
| 153 |
+
elif strategy == "long":
|
| 154 |
+
pool = [c for c in captions if len(c.split()) >= max_words]
|
| 155 |
+
return random.choice(pool) if pool else random.choice(captions)
|
| 156 |
+
|
| 157 |
+
else:
|
| 158 |
+
# Treat unknown strategy as filtered
|
| 159 |
+
pool = filter_low_quality_captions(captions, min_words, max_words)
|
| 160 |
+
return random.choice(pool) if pool else random.choice(captions)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _pick_caption(example, cfg=None):
|
| 165 |
+
"""
|
| 166 |
+
Pick one caption using cfg.caption_strategy (default: 'filtered').
|
| 167 |
+
Falls back to any caption > 3 words if cfg is None.
|
| 168 |
+
"""
|
| 169 |
+
if cfg is None:
|
| 170 |
+
caps = [c for c in example["captions"] if len(c.split()) > 3]
|
| 171 |
+
return random.choice(caps) if caps else random.choice(example["captions"])
|
| 172 |
+
return pick_caption_by_strategy(
|
| 173 |
+
example["captions"],
|
| 174 |
+
strategy=getattr(cfg, "caption_strategy", "filtered"),
|
| 175 |
+
min_words=getattr(cfg, "caption_min_words", 5),
|
| 176 |
+
max_words=getattr(cfg, "caption_max_words", 25),
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def get_dataloaders_for_model(cfg, model_type: str, processor, tokenizer=None):
|
| 181 |
+
"""
|
| 182 |
+
Unified dataloader factory for BLIP, ViT-GPT2, and GIT.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
cfg : CFG dataclass
|
| 186 |
+
model_type : 'blip' | 'vit_gpt2' | 'git'
|
| 187 |
+
processor : image processor / AutoProcessor
|
| 188 |
+
tokenizer : text tokenizer (required only for 'vit_gpt2')
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
train_loader, val_loader
|
| 192 |
+
"""
|
| 193 |
+
seed_all(cfg.seed)
|
| 194 |
+
|
| 195 |
+
print(f"Loading dataset ({model_type}): {cfg.dataset_id}...")
|
| 196 |
+
ds = load_dataset(
|
| 197 |
+
cfg.dataset_id,
|
| 198 |
+
storage_options={"client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)}},
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
train_split = "train"
|
| 202 |
+
val_split = "validation" if "validation" in ds else ("val" if "val" in ds else "train")
|
| 203 |
+
|
| 204 |
+
train_ds = ds[train_split].shuffle(seed=cfg.seed).select(
|
| 205 |
+
range(min(cfg.train_samples, len(ds[train_split])))
|
| 206 |
+
)
|
| 207 |
+
val_ds = ds[val_split].shuffle(seed=cfg.seed + 1).select(
|
| 208 |
+
range(min(cfg.val_samples, len(ds[val_split])))
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
print(f"✅ Training: {len(train_ds)} | Validation: {len(val_ds)}")
|
| 212 |
+
|
| 213 |
+
if model_type == "blip":
|
| 214 |
+
def collate_fn(examples):
|
| 215 |
+
images = [ex["image"].convert("RGB") for ex in examples]
|
| 216 |
+
captions = [_pick_caption(ex) for ex in examples]
|
| 217 |
+
encoding = processor(
|
| 218 |
+
images=images, text=captions,
|
| 219 |
+
padding="max_length", truncation=True,
|
| 220 |
+
max_length=cfg.max_target_len, return_tensors="pt",
|
| 221 |
+
)
|
| 222 |
+
encoding["labels"] = encoding["input_ids"].clone()
|
| 223 |
+
return encoding
|
| 224 |
+
|
| 225 |
+
elif model_type == "vit_gpt2":
|
| 226 |
+
assert tokenizer is not None, "tokenizer required for vit_gpt2"
|
| 227 |
+
def collate_fn(examples):
|
| 228 |
+
images = [ex["image"].convert("RGB") for ex in examples]
|
| 229 |
+
captions = [_pick_caption(ex) for ex in examples]
|
| 230 |
+
pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]
|
| 231 |
+
text_enc = tokenizer(
|
| 232 |
+
captions, padding="max_length", truncation=True,
|
| 233 |
+
max_length=cfg.max_target_len, return_tensors="pt",
|
| 234 |
+
)
|
| 235 |
+
labels = text_enc["input_ids"].clone()
|
| 236 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
| 237 |
+
return {
|
| 238 |
+
"pixel_values": pixel_values,
|
| 239 |
+
"labels": labels,
|
| 240 |
+
"decoder_attention_mask": text_enc["attention_mask"],
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
elif model_type == "git":
|
| 244 |
+
def collate_fn(examples):
|
| 245 |
+
images = [ex["image"].convert("RGB") for ex in examples]
|
| 246 |
+
captions = [_pick_caption(ex) for ex in examples]
|
| 247 |
+
encoding = processor(
|
| 248 |
+
images=images, text=captions,
|
| 249 |
+
padding="max_length", truncation=True,
|
| 250 |
+
max_length=cfg.max_target_len, return_tensors="pt",
|
| 251 |
+
)
|
| 252 |
+
labels = encoding["input_ids"].clone()
|
| 253 |
+
labels[labels == processor.tokenizer.pad_token_id] = -100
|
| 254 |
+
encoding["labels"] = labels
|
| 255 |
+
return encoding
|
| 256 |
+
|
| 257 |
+
else:
|
| 258 |
+
raise ValueError(f"Unknown model_type: {model_type}")
|
| 259 |
+
|
| 260 |
+
loader_kwargs = dict(
|
| 261 |
+
batch_size=cfg.batch_size,
|
| 262 |
+
num_workers=cfg.num_workers,
|
| 263 |
+
collate_fn=collate_fn,
|
| 264 |
+
pin_memory=torch.cuda.is_available(),
|
| 265 |
+
)
|
| 266 |
+
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
|
| 267 |
+
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
|
| 268 |
+
return train_loader, val_loader
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 272 |
+
# Custom VLM DataLoader (Character-Level Tokenization)
|
| 273 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 274 |
+
|
| 275 |
+
class COCOCharDataset(Dataset):
|
| 276 |
+
"""
|
| 277 |
+
Maps COCO images → (pixel_values, text_input_ids, text_targets)
|
| 278 |
+
using a character-level vocabulary built from the Shakespeare corpus.
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
def __init__(self, hf_dataset, image_processor, char_to_idx, max_target_len):
|
| 282 |
+
self.ds = hf_dataset
|
| 283 |
+
self.image_processor = image_processor
|
| 284 |
+
self.char_to_idx = char_to_idx
|
| 285 |
+
self.max_target_len = max_target_len
|
| 286 |
+
self.unk_idx = char_to_idx.get(" ", 0)
|
| 287 |
+
|
| 288 |
+
def _encode_text(self, text):
|
| 289 |
+
"""Encode a string to a fixed-length char index tensor."""
|
| 290 |
+
ids = [self.char_to_idx.get(c, self.unk_idx) for c in text[:self.max_target_len]]
|
| 291 |
+
# Pad with 0s if shorter
|
| 292 |
+
ids += [0] * (self.max_target_len - len(ids))
|
| 293 |
+
return ids
|
| 294 |
+
|
| 295 |
+
def __len__(self):
|
| 296 |
+
return len(self.ds)
|
| 297 |
+
|
| 298 |
+
def __getitem__(self, idx):
|
| 299 |
+
ex = self.ds[idx]
|
| 300 |
+
image = ex["image"].convert("RGB")
|
| 301 |
+
pixel_values = self.image_processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)
|
| 302 |
+
|
| 303 |
+
# Pick one caption
|
| 304 |
+
caps = [c for c in ex["captions"] if len(c.split()) > 3] or ex["captions"]
|
| 305 |
+
caption = random.choice(caps).lower()
|
| 306 |
+
|
| 307 |
+
src_ids = self._encode_text(caption[:-1]) # input: all but last char
|
| 308 |
+
tgt_ids = self._encode_text(caption[1:]) # target: shifted right by 1
|
| 309 |
+
|
| 310 |
+
return {
|
| 311 |
+
"pixel_values": pixel_values,
|
| 312 |
+
"text_input_ids": torch.tensor(src_ids, dtype=torch.long),
|
| 313 |
+
"text_targets": torch.tensor(tgt_ids, dtype=torch.long),
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def get_custom_vlm_dataloader(cfg, char_to_idx):
|
| 318 |
+
"""
|
| 319 |
+
Returns (train_loader, val_loader) for the Custom VLM using COCO images
|
| 320 |
+
and character-level tokenization.
|
| 321 |
+
|
| 322 |
+
Requires the ViT image processor separately.
|
| 323 |
+
"""
|
| 324 |
+
from transformers import ViTImageProcessor
|
| 325 |
+
|
| 326 |
+
seed_all(cfg.seed)
|
| 327 |
+
|
| 328 |
+
image_processor = ViTImageProcessor.from_pretrained(cfg.vit_encoder_id, use_fast=True)
|
| 329 |
+
|
| 330 |
+
print(f"Loading dataset (Custom VLM): {cfg.dataset_id}...")
|
| 331 |
+
ds = load_dataset(
|
| 332 |
+
cfg.dataset_id,
|
| 333 |
+
storage_options={"client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)}},
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
train_split = "train"
|
| 337 |
+
val_split = "validation" if "validation" in ds else ("val" if "val" in ds else "train")
|
| 338 |
+
|
| 339 |
+
train_hf = ds[train_split].shuffle(seed=cfg.seed).select(
|
| 340 |
+
range(min(cfg.train_samples, len(ds[train_split])))
|
| 341 |
+
)
|
| 342 |
+
val_hf = ds[val_split].shuffle(seed=cfg.seed + 1).select(
|
| 343 |
+
range(min(cfg.val_samples, len(ds[val_split])))
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
train_ds = COCOCharDataset(train_hf, image_processor, char_to_idx, cfg.max_target_len)
|
| 347 |
+
val_ds = COCOCharDataset(val_hf, image_processor, char_to_idx, cfg.max_target_len)
|
| 348 |
+
|
| 349 |
+
print(f"✅ Custom VLM — Training: {len(train_ds)} | Validation: {len(val_ds)}")
|
| 350 |
+
|
| 351 |
+
loader_kwargs = dict(
|
| 352 |
+
batch_size=cfg.batch_size,
|
| 353 |
+
num_workers=cfg.num_workers,
|
| 354 |
+
pin_memory=torch.cuda.is_available(),
|
| 355 |
+
)
|
| 356 |
+
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
|
| 357 |
+
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
|
| 358 |
+
return train_loader, val_loader
|
detailed_technical_report_cross_attention_vlm_image_captioning.md
ADDED
|
@@ -0,0 +1,748 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Detailed Technical Report: Cross-Attention Strategies in Vision-Language Models for Image Captioning
|
| 2 |
+
|
| 3 |
+
**Author:** Manoj Kumar
|
| 4 |
+
**Project:** VLM Caption Lab
|
| 5 |
+
**Date:** 4 March 2026
|
| 6 |
+
**Dataset:** MS-COCO Captions (`whyen-wang/coco_captions`)
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## Table of Contents
|
| 11 |
+
|
| 12 |
+
1. [Introduction and Motivation](#1-introduction-and-motivation)
|
| 13 |
+
2. [The Central Question: How Should Vision Meet Language?](#2-the-central-question-how-should-vision-meet-language)
|
| 14 |
+
3. [Dataset and Data Quality Engineering](#3-dataset-and-data-quality-engineering)
|
| 15 |
+
4. [Architecture Deep Dive: Four Ways to Fuse Vision and Text](#4-architecture-deep-dive-four-ways-to-fuse-vision-and-text)
|
| 16 |
+
5. [Building a Custom Vision-Language Model from Scratch — The Full Story](#5-building-a-custom-vision-language-model-from-scratch--the-full-story)
|
| 17 |
+
6. [Training Pipeline: Making It All Work](#6-training-pipeline-making-it-all-work)
|
| 18 |
+
7. [Experiments and Results](#7-experiments-and-results)
|
| 19 |
+
8. [The Streamlit Application](#8-the-streamlit-application)
|
| 20 |
+
9. [Key Insights and Analytical Conclusions](#9-key-insights-and-analytical-conclusions)
|
| 21 |
+
10. [Future Improvements](#10-future-improvements)
|
| 22 |
+
11. [Reproducibility and Commands](#11-reproducibility-and-commands)
|
| 23 |
+
12. [Project Structure](#12-project-structure)
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
## 1. Introduction and Motivation
|
| 28 |
+
|
| 29 |
+
Image captioning sits at the intersection of computer vision and natural language processing. The task sounds deceptively simple: given a photograph, produce a sentence that describes what is happening in it. But underneath that simplicity lies a fundamental engineering question — **how exactly should a model look at an image while it is writing a sentence about it?**
|
| 30 |
+
|
| 31 |
+
This project was born out of a desire to understand that question from the ground up. Rather than just using one pre-trained model and calling it "good enough," I wanted to build a pipeline that puts **four fundamentally different architectures** side by side — trained on the same dataset, measured by the same evaluation metric, and running on the same hardware — and then systematically test what happens when you change how vision and language interact.
|
| 32 |
+
|
| 33 |
+
The four architectures I chose each represent a distinct philosophy about multimodal fusion:
|
| 34 |
+
|
| 35 |
+
- **BLIP** uses a gated cross-attention mechanism where the decoder can selectively filter how much visual information flows into each text token.
|
| 36 |
+
- **ViT-GPT2** (Vision Transformer paired with GPT-2) takes the brute-force approach: full cross-attention at every decoder layer, with every text token attending to every image patch.
|
| 37 |
+
- **GIT** (Generative Image-to-text Transformer) throws out cross-attention entirely and concatenates image embeddings directly into the text sequence, treating everything as a single self-attention problem.
|
| 38 |
+
- **Custom VLM** (Custom Vision-Language Model) is a model I built from scratch, combining a frozen Vision Transformer with a character-level Transformer decoder that was originally trained on Shakespeare's complete works.
|
| 39 |
+
|
| 40 |
+
That last one — the Custom VLM — is where the most interesting engineering challenges emerged, and where I learned the most about what it actually takes to make two models from completely different domains work together.
|
| 41 |
+
|
| 42 |
+
### What This Report Covers
|
| 43 |
+
|
| 44 |
+
This report documents **every architectural choice, every bug, every experiment, and every insight** from this project. It is written as a narrative — not a dry summary of results — because the debugging process itself taught me more than the final numbers did.
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## 2. The Central Question: How Should Vision Meet Language?
|
| 49 |
+
|
| 50 |
+
Before diving into implementation, it helps to understand the core architectural decision that differentiates these four models: **the role of cross-attention.**
|
| 51 |
+
|
| 52 |
+
**What is self-attention?** In a standard Transformer (the architecture behind models like GPT), self-attention allows each word in a sentence to look at every other word in the same sentence. This is how the model understands context — the word "bank" can mean a financial institution or a river bank, and self-attention helps the model figure out which one based on surrounding words.
|
| 53 |
+
|
| 54 |
+
**What is cross-attention?** Cross-attention extends this idea by allowing words from one sequence (say, text) to look at tokens from a *different* sequence (say, image patches). This is how most encoder-decoder models connect their visual understanding to their language generation. The text decoder says, "I am about to write the next word — let me look at the image to decide what it should be."
|
| 55 |
+
|
| 56 |
+
**But here is the interesting part — cross-attention is not the only way to do this.** Some models skip it entirely. GIT, for example, concatenates image patch embeddings directly in front of text token embeddings and runs the whole thing through a single self-attention Transformer. There is no separate "looking at the image" computation. The model just treats image patches as very unusual text tokens.
|
| 57 |
+
|
| 58 |
+
My Custom VLM does something similar but with a twist: it projects visual embeddings through a trainable MLP (Multi-Layer Perceptron — a small neural network with two layers) into the character-level decoder's embedding space <b>(My personal decoder transformer built from scratch)</b>, and then the decoder processes the visual prefix alongside character embeddings using regular self-attention.
|
| 59 |
+
|
| 60 |
+
The table below summarizes how each architecture handles this fusion:
|
| 61 |
+
|
| 62 |
+
| Architecture | Fusion Mechanism | Has Cross-Attention? | Can We Test Masking? |
|
| 63 |
+
|---|---|---|---|
|
| 64 |
+
| **BLIP** | Gated cross-attention inserted between self-attention and feed-forward layers in the decoder | ✅ Yes | ✅ Yes — via `encoder_attention_mask` |
|
| 65 |
+
| **ViT-GPT2** | Standard full cross-attention at every GPT-2 layer | ✅ Yes | ✅ Yes — via `encoder_attention_mask` |
|
| 66 |
+
| **GIT** | Image tokens concatenated as prefix → single self-attention | ❌ No | ❌ No — no separate encoder mask |
|
| 67 |
+
| **Custom VLM** | MLP (Multi-Layer Perceptron) projection → visual prefix + character embeddings → self-attention | ❌ No | ❌ No — visual prefix is part of sequence |
|
| 68 |
+
|
| 69 |
+
### The Fusion Formulas (What Happens Mathematically)
|
| 70 |
+
|
| 71 |
+
For one who is interested in the math, here is how each model processes vision and text internally:
|
| 72 |
+
|
| 73 |
+
- **ViT-GPT2 (Full Cross-Attention):**
|
| 74 |
+
- `text_output = CrossAttention(Query=text_hidden, Key=image_hidden, Value=image_hidden)`
|
| 75 |
+
- Every text token directly queries every image patch
|
| 76 |
+
|
| 77 |
+
- **BLIP (Gated Multimodal Cross-Attention):**
|
| 78 |
+
- Step 1: `h = SelfAttention(text_hidden)` — text tokens attend to each other
|
| 79 |
+
- Step 2: `h = h + gate × CrossAttention(Query=h, Key=image_hidden, Value=image_hidden)` — learnable gate controls image flow
|
| 80 |
+
- Step 3: `h = FeedForward(h)` — final transformation
|
| 81 |
+
- The **gate** is what makes BLIP special — it learns to close when generating syntax words ("the", "a") and open when generating content words ("dog", "standing")
|
| 82 |
+
|
| 83 |
+
- **GIT (Self-Attention Prefix — No Cross-Attention):**
|
| 84 |
+
- `combined_sequence = [image_patches ; text_tokens]`
|
| 85 |
+
- `output = CausalSelfAttention(combined_sequence)`
|
| 86 |
+
- Everything is one sequence — no separate image processing step
|
| 87 |
+
|
| 88 |
+
- **Custom VLM (Visual Prefix-Tuning):**
|
| 89 |
+
- Step 1: `visual_prefix = MLP(ViT_encoder(image))` — project image patches into text space
|
| 90 |
+
- Step 2: `input = [visual_prefix ; character_embeddings]` — concatenate
|
| 91 |
+
- Step 3: `output = CausalSelfAttention(input)` — process as one sequence
|
| 92 |
+
- Step 4: `logits = LanguageHead(output[after_visual_prefix:])` — predict characters
|
| 93 |
+
|
| 94 |
+
---
|
| 95 |
+
|
| 96 |
+
## 3. Dataset and Data Quality Engineering
|
| 97 |
+
|
| 98 |
+
### 3.1 The Dataset
|
| 99 |
+
|
| 100 |
+
I used the **MS-COCO Captions dataset** from HuggingFace (`whyen-wang/coco_captions`). COCO (Common Objects in Context) is the standard benchmark for image captioning — it contains natural photographs of everyday scenes, each annotated with five human-written captions describing the image.
|
| 101 |
+
|
| 102 |
+
**Why COCO?** It is the most widely used benchmark in image captioning research, which makes my results directly comparable to published papers. It also has high-quality human annotations — each image has five independent descriptions, giving multiple valid reference points for evaluation.
|
| 103 |
+
|
| 104 |
+
The data split I used:
|
| 105 |
+
|
| 106 |
+
| | Training Images | Validation Images |
|
| 107 |
+
|-|---|---|
|
| 108 |
+
| BLIP | 30,000 | 2,000 |
|
| 109 |
+
| ViT-GPT2 / GIT | 15,000 | 1,500 |
|
| 110 |
+
| Custom VLM | 15,000 | 1,500 |
|
| 111 |
+
|
| 112 |
+
BLIP gets more data because it is the largest model (224 million parameters) and benefits more from additional training examples. The smaller models converged adequately with 15,000 samples.
|
| 113 |
+
|
| 114 |
+
### 3.2 The Caption Quality Problem
|
| 115 |
+
|
| 116 |
+
One thing I noticed early on is that COCO captions are not uniformly useful for training. Some captions are extremely short — just "Dog" or "A cat" — while others are excessively long, rambling 40-word descriptions. During initial training, I found that treating every caption equally added noise: the model would sometimes learn to generate one-word descriptions, other times try to produce paragraphs.
|
| 117 |
+
|
| 118 |
+
I ran a systematic analysis on the caption word-count distribution:
|
| 119 |
+
|
| 120 |
+
| Metric | Value |
|
| 121 |
+
|---|---|
|
| 122 |
+
| Total captions sampled | 1,000 |
|
| 123 |
+
| Mean word count | 10.4 words |
|
| 124 |
+
| Range | 7 – 28 words |
|
| 125 |
+
| 10th percentile | 8 words |
|
| 126 |
+
| 50th percentile (median) | 10 words |
|
| 127 |
+
| 90th percentile | 13 words |
|
| 128 |
+
| % under 5 words | 0.0% |
|
| 129 |
+
| % over 25 words | 0.2% |
|
| 130 |
+
|
| 131 |
+
### 3.3 Caption Filtering Strategies
|
| 132 |
+
|
| 133 |
+
To address the caption quality problem, I implemented a configurable caption filtering pipeline in `data_prep.py` with five strategies:
|
| 134 |
+
|
| 135 |
+
1. **`raw`** — Pick any random caption from the five available. No filtering at all.
|
| 136 |
+
2. **`filtered`** — Only use captions between 5 and 25 words. Falls back to a random caption if none qualify. **This is the recommended default.**
|
| 137 |
+
3. **`short`** — Prefer captions with 9 or fewer words. Trains the model to be concise.
|
| 138 |
+
4. **`long`** — Prefer captions with 12 or more words. Trains the model to be descriptive.
|
| 139 |
+
5. **`mixed`** — Randomly switch between short, medium, and long strategies each time.
|
| 140 |
+
|
| 141 |
+
The filtering is implemented through the `pick_caption_by_strategy()` function, which is called during dataset construction. The strategy is configurable through `configs/base_config.py`:
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
caption_strategy: str = "filtered" # recommended default
|
| 145 |
+
caption_min_words: int = 5
|
| 146 |
+
caption_max_words: int = 25
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
### 3.4 Character-Level Tokenization for the Custom VLM
|
| 150 |
+
|
| 151 |
+
Most modern language models use **subword tokenization** (called BPE — Byte Pair Encoding), where common words are single tokens and rare words are split into pieces. For example, GPT-2 treats "standing" as a single token.
|
| 152 |
+
|
| 153 |
+
My Custom VLM does something different — it uses a **character-level vocabulary of 65 characters** built from Shakespeare's complete works. This means the sentence "a man standing in front of a tree" gets encoded as individual characters: `a`, ` `, `m`, `a`, `n`, ` `, `s`, `t`, `a`, `n`, `d`, `i`, `n`, `g`... That is roughly 35 character tokens, compared to about 8 subword tokens in GPT-2.
|
| 154 |
+
|
| 155 |
+
**Why character-level?** This was a deliberate design choice — the Shakespeare decoder was built for character generation, and changing the tokenizer would require retraining from scratch. It makes the Custom VLM's job harder but also more instructive: it forces the model to learn English spelling on top of learning to describe images.
|
| 156 |
+
|
| 157 |
+
The `COCOCharDataset` class in `data_prep.py` handles this conversion, encoding each caption into a sequence of character indices and padding to `max_target_len=128`.
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## 4. Architecture Deep Dive: Four Ways to Fuse Vision and Text
|
| 162 |
+
|
| 163 |
+
### 4.1 BLIP — Gated Multimodal Mixture Attention
|
| 164 |
+
|
| 165 |
+
> **Model:** `Salesforce/blip-image-captioning-base` | **Parameters:** 224 million
|
| 166 |
+
|
| 167 |
+
BLIP's architecture is called a **Multimodal mixture of Encoder-Decoder (MED)**. The key innovation is how it injects visual information into the text decoder: between the self-attention and feed-forward sub-layers at each decoder block, there is a **cross-attention sub-layer with a learnable gate.**
|
| 168 |
+
|
| 169 |
+
**What does the gate do?** When the decoder is generating a purely syntactic token (like "the" or "is"), the gate can learn to close — effectively ignoring the image. When the decoder needs to produce a content word (like "dog" or "standing"), the gate opens to let visual features through. This selective attention prevents what researchers call "attention collapse," where the model becomes so distracted by visual features that it loses track of grammar.
|
| 170 |
+
|
| 171 |
+
In my implementation (`models/blip_tuner.py`), I load the model with **gradient checkpointing** enabled (which trades computation time for reduced memory usage — instead of keeping all intermediate values in memory for the backward pass, it recomputes them on the fly). I also resize images to 224×224 pixels to fit within Apple Silicon memory constraints.
|
| 172 |
+
|
| 173 |
+
**The `generate_with_mask()` function** is critical — it allows inference-time masking by accepting a custom attention mask that restricts which image patches the decoder can see. This is what powers the ablation experiment described in Section 7.1.
|
| 174 |
+
|
| 175 |
+
### 4.2 ViT-GPT2 — Standard Full Cross-Attention
|
| 176 |
+
|
| 177 |
+
> **Model:** `nlpconnect/vit-gpt2-image-captioning` | **Parameters:** 239 million
|
| 178 |
+
|
| 179 |
+
This is the brute-force baseline. ViT-GPT2 is a **VisionEncoderDecoderModel** that pairs:
|
| 180 |
+
- **Vision Transformer (ViT)** as the image encoder — takes a 224×224 image and splits it into a 14×14 grid of patches (196 patches + 1 special class token = 197 total), each represented as a 768-dimensional vector
|
| 181 |
+
- **GPT-2** as the text decoder — generates text one word at a time
|
| 182 |
+
|
| 183 |
+
At every decoder layer, an explicit cross-attention block lets **each text token attend to all 197 ViT patch embeddings**. Every word the model generates has full access to every part of the image at every layer.
|
| 184 |
+
|
| 185 |
+
**Advantage:** Maximum information flow — nothing is filtered or hidden.
|
| 186 |
+
**Disadvantage:** Computationally expensive, and the constant stream of visual input can sometimes confuse the language generation.
|
| 187 |
+
|
| 188 |
+
### 4.3 GIT — Zero Cross-Attention Architecture
|
| 189 |
+
|
| 190 |
+
> **Model:** `microsoft/git-base-coco` | **Parameters:** 177 million
|
| 191 |
+
|
| 192 |
+
GIT (Generative Image-to-text Transformer) represents a fundamentally different philosophy: **instead of adding cross-attention layers to connect vision and language, GIT concatenates image patch embeddings directly in front of text tokens to form a single flat sequence:**
|
| 193 |
+
|
| 194 |
+
```
|
| 195 |
+
[image_patch_1, image_patch_2, ..., image_patch_N, text_token_1, text_token_2, ...]
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
A single causal self-attention Transformer processes the entire sequence. There are no dedicated cross-attention blocks. The vision-language fusion happens implicitly through positional self-attention — text tokens at the end of the sequence naturally attend to image patches at the beginning.
|
| 199 |
+
|
| 200 |
+
**Why this is clever:** It eliminates an entire class of parameters (all the cross-attention weights), making the model smaller (177 million vs. 239 million for ViT-GPT2) and faster. The trade-off is that the model cannot separately control "how much to look at the image" versus "how much to focus on previously generated text."
|
| 201 |
+
|
| 202 |
+
**Important limitation for experiments:** Because GIT processes vision and text in a single sequence with no separate encoder, it does not have an `encoder_attention_mask` parameter. This means my masking ablation experiments (Section 7.1) cannot be applied to GIT.
|
| 203 |
+
|
| 204 |
+
### 4.4 Custom VLM — Visual Prefix-Tuning with Shakespeare Decoder
|
| 205 |
+
|
| 206 |
+
> **Parameters:** 103 million total, but only **16.2 million trainable** (the rest are frozen)
|
| 207 |
+
|
| 208 |
+
This is the model I built from scratch, and it is where most of the engineering effort went. The architecture has three components:
|
| 209 |
+
|
| 210 |
+
**Component 1: Frozen Vision Transformer (ViT) Encoder**
|
| 211 |
+
A standard ViT pre-trained on ImageNet-21K (`google/vit-base-patch16-224-in21k`). It takes a 224×224 image and produces 197 patch embeddings, each 768-dimensional. **These weights are completely frozen during training** — I do not want to disturb the image understanding capabilities that the model already learned on ImageNet.
|
| 212 |
+
|
| 213 |
+
**Component 2: Trainable MLP Bridge (The Critical Connection)**
|
| 214 |
+
This is the only component connecting vision to language. It is a small two-layer neural network (Multi-Layer Perceptron) that projects each 768-dimensional visual embedding down to the decoder's 384-dimensional embedding space:
|
| 215 |
+
|
| 216 |
+
```python
|
| 217 |
+
self.visual_projection = nn.Sequential(
|
| 218 |
+
nn.Linear(768, 1536), # expand from 768 to 1536 dimensions
|
| 219 |
+
nn.GELU(), # nonlinear activation function
|
| 220 |
+
nn.Linear(1536, 384) # compress down to 384 dimensions
|
| 221 |
+
)
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
**Why two layers instead of one?** This is explained in detail in Section 5 — a single linear layer was not enough because it cannot perform the nonlinear transformation needed to translate between visual and textual feature spaces.
|
| 225 |
+
|
| 226 |
+
**Component 3: Shakespeare-Pretrained Character-Level Decoder**
|
| 227 |
+
8 Transformer blocks, 8 attention heads, 384-dimensional embeddings, and a vocabulary of just 65 characters. This decoder was originally trained to generate Shakespeare text, character by character. During fine-tuning, both the MLP bridge and the decoder are trainable, with different learning rates.
|
| 228 |
+
|
| 229 |
+
**How the full pipeline works:**
|
| 230 |
+
1. ViT processes the image → 197 patches × 768 dimensions
|
| 231 |
+
2. MLP projects each patch → 197 patches × 384 dimensions (these become the "visual prefix")
|
| 232 |
+
3. Character embeddings for the caption text are looked up → T characters × 384 dimensions
|
| 233 |
+
4. Visual prefix and character embeddings are concatenated into one sequence
|
| 234 |
+
5. A causal self-attention mask is applied, and the full Transformer decoder processes the sequence
|
| 235 |
+
6. The language model head produces logits (predictions) only for the text portion (positions after the visual prefix)
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
## 5. Building a Custom Vision-Language Model from Scratch — The Full Story
|
| 240 |
+
|
| 241 |
+
This section tells the complete narrative of building the Custom VLM, including every bug, every failed experiment, and every fix. **This was the most educational part of the entire project,** and it demonstrates the kind of debugging that real machine learning engineering requires.
|
| 242 |
+
|
| 243 |
+
### 5.1 The Starting Point: A Shakespeare Decoder
|
| 244 |
+
|
| 245 |
+
The journey started with a character-level Transformer I had previously trained on the complete works of Shakespeare (~1 MB of Elizabethan English). This model could generate passable Shakespeare prose — things like "To be or not to be, that is the question" continuations. It had 8 Transformer blocks, 8 attention heads, 384-dimensional embeddings, and a 65-character vocabulary.
|
| 246 |
+
|
| 247 |
+
The idea was simple: if this decoder already understands English (even old English), maybe I could teach it to describe images by just showing it visual features as a prefix. I would freeze the ViT, freeze the Shakespeare decoder, and **only train a small projection layer** to translate from ViT's 768-dimensional visual space to the decoder's 384-dimensional text space.
|
| 248 |
+
|
| 249 |
+
This approach is called **"visual prefix-tuning"** and it is conceptually similar to what LLaVA (Large Language and Vision Assistant) does, except LLaVA uses GPT-4-level decoders and I am using a tiny character-level model.
|
| 250 |
+
|
| 251 |
+
### 5.2 Stage 1: The Linear Projection Bottleneck (Training Loss Stuck at 2.92)
|
| 252 |
+
|
| 253 |
+
My first implementation used a single linear layer for the projection:
|
| 254 |
+
|
| 255 |
+
```python
|
| 256 |
+
# Original (broken) — just one matrix multiplication
|
| 257 |
+
self.visual_projection = nn.Linear(768, 384)
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
I trained this for 15 epochs and watched the training loss. It dropped quickly at first — from around 4.5 down to about 3.5 — but then hit a rigid plateau at approximately **2.922** and refused to budge. Epoch after epoch, the loss hovered around 2.92, never improving.
|
| 261 |
+
|
| 262 |
+
The generated text was complete gibberish: strings like `"iGiiiiiGiviqiGqiFliqiGidlidiliGilFGilqiiiqiiiiGii"`. The CIDEr score was **0.0000** — literally zero. Not a single word overlapped with any human reference caption.
|
| 263 |
+
|
| 264 |
+
> **Why this happened:** A single linear projection is just a matrix multiplication — it can rotate and scale the visual embeddings, but it cannot perform the kind of nonlinear transformation needed to translate between two fundamentally different feature spaces. ViT's 768-dimensional space encodes visual concepts (edges, textures, object boundaries), while the decoder's 384-dimensional space encodes character-level language patterns. Mapping between these with just a matrix multiply is like trying to translate French to Chinese using only a ruler — the tool simply lacks the expressive power.
|
| 265 |
+
|
| 266 |
+
### 5.3 Stage 1 Fix: Upgrading to a Two-Layer MLP (Inspired by LLaVA)
|
| 267 |
+
|
| 268 |
+
I replaced the single linear layer with a two-layer MLP (Multi-Layer Perceptron):
|
| 269 |
+
|
| 270 |
+
```python
|
| 271 |
+
# Fixed — two layers with GELU nonlinearity
|
| 272 |
+
self.visual_projection = nn.Sequential(
|
| 273 |
+
nn.Linear(768, 1536), # 768 → 1536 (expand to give room for learning)
|
| 274 |
+
nn.GELU(), # nonlinear activation function
|
| 275 |
+
nn.Linear(1536, 384) # 1536 → 384 (compress to decoder's dimension)
|
| 276 |
+
)
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
**What is GELU?** GELU (Gaussian Error Linear Unit) is an activation function — a mathematical function that introduces nonlinearity. Without it, stacking two linear layers is mathematically equivalent to a single linear layer. The GELU between the two layers gives the projection the ability to learn nonlinear boundaries — meaning it can map visual concepts to text concepts in ways that a simple scaling/rotation cannot.
|
| 280 |
+
|
| 281 |
+
**Why 1536 as the middle dimension?** This is 2× the input dimension (768), providing a wide intermediate representation where the model can "reason" about how visual concepts map to textual concepts before compressing down to 384. This is the same approach used by LLaVA.
|
| 282 |
+
|
| 283 |
+
### 5.4 Stage 2: Why Training Loss Alone Is Not Enough
|
| 284 |
+
|
| 285 |
+
Even after the MLP upgrade, I realized I had a **measurement problem**. The training loss was going down, but I had no way to know if the actual captions were any good.
|
| 286 |
+
|
| 287 |
+
**What is training loss?** Training loss (specifically, cross-entropy loss) measures the probability the model assigns to the correct next token given all previous tokens. It is a mathematical surrogate — a number the optimizer tries to minimize — but it does not directly measure caption quality. A model can achieve low cross-entropy loss while generating grammatically incorrect, semantically meaningless text.
|
| 288 |
+
|
| 289 |
+
**What is CIDEr?** CIDEr (Consensus-based Image Description Evaluation) is a metric specifically designed for image captioning. It compares the caption our model generates to five human-written descriptions of the same image using n-gram overlap (matching sequences of consecutive words), weighted by TF-IDF (a technique that gives more weight to descriptive words like "bicycle" and less weight to common words like "the"). **A higher CIDEr score means the generated caption sounds more like what a human would write.**
|
| 290 |
+
|
| 291 |
+
| Metric | What It Measures | Reliable? |
|
| 292 |
+
|---|---|---|
|
| 293 |
+
| Training Loss | How well model predicts next token on training data | ❌ Can be misleading — low loss ≠ good captions |
|
| 294 |
+
| Validation Loss | How well model predicts next token on unseen data | ⚠️ Better, but still a surrogate |
|
| 295 |
+
| **CIDEr Score** | **How closely generated captions match human descriptions** | **✅ The gold standard for captioning** |
|
| 296 |
+
|
| 297 |
+
**The pipeline changes I made to `train.py`:**
|
| 298 |
+
|
| 299 |
+
1. **Validation loss tracking** — At the end of every epoch, run a forward pass on a validation subset to detect overfitting (when training loss drops but validation loss rises, the model is memorizing training data instead of learning general patterns).
|
| 300 |
+
|
| 301 |
+
2. **Live CIDEr computation** — Actually generate captions using beam search on the validation set, then score them with the `pycocoevalcap` CIDEr scorer. This tells me if the model is producing good English descriptions, not just achieving low loss numbers.
|
| 302 |
+
|
| 303 |
+
3. **CIDEr-based checkpointing** — Save the `best/` checkpoint based on the **highest validation CIDEr**, not the lowest training loss. This ensures the saved model is the one that actually produces the best captions.
|
| 304 |
+
|
| 305 |
+
The epoch-end logging now shows all three metrics:
|
| 306 |
+
```
|
| 307 |
+
Epoch 11/15 avg loss (Train): 0.8573
|
| 308 |
+
Running Validation (Loss & CIDEr)...
|
| 309 |
+
Validation Loss: 0.8077
|
| 310 |
+
Validation CIDEr: 0.2863
|
| 311 |
+
🏆 New best CIDEr! Saved → ./outputs/custom_vlm/best
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
### 5.5 Stage 3: The Gibberish Mystery — 337 Out of 342 Weights Silently Failed to Load
|
| 315 |
+
|
| 316 |
+
This was the most painful and instructive bug of the entire project. Even with the MLP upgrade and CIDEr pipeline in place, the model was **still generating pure gibberish**. I could see the loss was dropping, the pipeline was working, but the outputs were nonsensical character sequences.
|
| 317 |
+
|
| 318 |
+
After day of investigation, I found the root cause: **an architecture mismatch between the Shakespeare checkpoint and the Custom VLM decoder.**
|
| 319 |
+
|
| 320 |
+
Here is what happened:
|
| 321 |
+
|
| 322 |
+
**The original Shakespeare model** was built with a custom per-head attention implementation. Each of its 8 attention heads had its own separate weight matrices:
|
| 323 |
+
|
| 324 |
+
```
|
| 325 |
+
blocks.0.sa_head.heads.0.key.weight → shape (48, 384) ← head 1
|
| 326 |
+
blocks.0.sa_head.heads.1.key.weight → shape (48, 384) ← head 2
|
| 327 |
+
blocks.0.sa_head.heads.2.key.weight → shape (48, 384) ← head 3
|
| 328 |
+
... (8 separate weight matrices per layer)
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
**But the Custom VLM decoder** used PyTorch's built-in `nn.TransformerEncoder`, which expects **fused** (combined) attention weights:
|
| 332 |
+
|
| 333 |
+
```
|
| 334 |
+
decoder_blocks.layers.0.self_attn.in_proj_weight → shape (1152, 384)
|
| 335 |
+
```
|
| 336 |
+
|
| 337 |
+
**These are completely different formats.** The per-head format has 8 separate small matrices. PyTorch's format concatenates all heads into a single large matrix. It is like trying to load 8 individual photos into a slot designed for one panoramic image.
|
| 338 |
+
|
| 339 |
+
To make matters worse, the original Custom VLM config used **6 blocks, 6 heads, and a block size of 512**, while the Shakespeare checkpoint had **8 blocks, 8 heads, and a block size of 256**. **Nothing matched.**
|
| 340 |
+
|
| 341 |
+
When I loaded the checkpoint with `strict=False`:
|
| 342 |
+
|
| 343 |
+
```python
|
| 344 |
+
model.load_state_dict(checkpoint, strict=False)
|
| 345 |
+
```
|
| 346 |
+
|
| 347 |
+
PyTorch silently compared the key names, found that almost none of them matched, and simply **skipped 337 out of 342 tensors**. Only 5 tensors loaded — the character embedding table and the language model head. **The entire decoder brain — all the self-attention layers and feed-forward networks — was left randomly initialized.**
|
| 348 |
+
|
| 349 |
+
And because `freeze_decoder()` was called immediately after loading, those random weights were frozen in place. The model was literally running on random noise, with no way to learn.
|
| 350 |
+
|
| 351 |
+
> **⚠️ This is why `strict=False` is dangerous.** PyTorch does not raise an error or even a warning when the vast majority of a model fails to load. It just silently skips mismatched keys, leaving the developer to discover the problem through painstaking debugging. **In production code, always check how many tensors actually loaded.**
|
| 352 |
+
|
| 353 |
+
### 5.6 Stage 3 Fix: Architecture Alignment + Weight Remapping + Decoder Unfreezing
|
| 354 |
+
|
| 355 |
+
The fix required three coordinated changes:
|
| 356 |
+
|
| 357 |
+
**Fix 1: Architecture Alignment**
|
| 358 |
+
I updated `custom_vlm_config.py` to exactly match the Shakespeare checkpoint dimensions:
|
| 359 |
+
|
| 360 |
+
```python
|
| 361 |
+
text_embed_dim: int = 384 # match Shakespeare (was different before)
|
| 362 |
+
n_heads: int = 8 # was 6, now 8 to match Shakespeare
|
| 363 |
+
n_layers: int = 8 # was 6, now 8 to match Shakespeare
|
| 364 |
+
block_size: int = 256 # was 512, now 256 to match Shakespeare
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
**Fix 2: Weight Remapping**
|
| 368 |
+
I completely rewrote the `load_shakespeare_weights()` method in `custom_vlm.py`. The new implementation reads each per-head weight from the Shakespeare checkpoint, concatenates the 8 head weights for Query, Key, and Value into a single fused matrix, and maps it to PyTorch's expected format:
|
| 369 |
+
|
| 370 |
+
```python
|
| 371 |
+
# For each Transformer layer, fuse 8 per-head (48, 384) weights
|
| 372 |
+
# into one (1152, 384) matrix that PyTorch expects
|
| 373 |
+
query_weights = []
|
| 374 |
+
key_weights = []
|
| 375 |
+
value_weights = []
|
| 376 |
+
for head_idx in range(8):
|
| 377 |
+
query_weights.append(ckpt[f"blocks.{layer}.sa_head.heads.{head_idx}.query.weight"])
|
| 378 |
+
key_weights.append(ckpt[f"blocks.{layer}.sa_head.heads.{head_idx}.key.weight"])
|
| 379 |
+
value_weights.append(ckpt[f"blocks.{layer}.sa_head.heads.{head_idx}.value.weight"])
|
| 380 |
+
|
| 381 |
+
in_proj_weight = torch.cat(query_weights + key_weights + value_weights, dim=0)
|
| 382 |
+
# Result: (1152, 384) = (3 attention_types × 8 heads × 48 dim_per_head, 384)
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
After loading, the method prints a verification count: **"96 of 96 decoder tensors loaded."** — all weights accounted for.
|
| 386 |
+
|
| 387 |
+
**Fix 3: Decoder Unfreezing with Discriminative Learning Rates**
|
| 388 |
+
Instead of freezing the decoder, I unfroze it and used **discriminative learning rates** — different learning speeds for different parts of the model:
|
| 389 |
+
|
| 390 |
+
- **Projection MLP:** Learning rate = `1e-4` (0.0001) — aggressive updates because this is randomly initialized and needs to learn the vision-to-text mapping from zero
|
| 391 |
+
- **Decoder:** Learning rate = `5e-5` (0.00005) — gentle updates because the Shakespeare weights are a good starting point and we just want to slowly adapt from Elizabethan English to modern captioning style
|
| 392 |
+
|
| 393 |
+
### 5.7 The Results: From Gibberish to English
|
| 394 |
+
|
| 395 |
+
**The difference was immediate and dramatic:**
|
| 396 |
+
|
| 397 |
+
| Metric | ❌ Before (Broken) | ✅ After (Fixed) |
|
| 398 |
+
|---|---|---|
|
| 399 |
+
| Decoder tensors loaded | 5 of 342 (1.4%) | **96 of 96 (100%)** |
|
| 400 |
+
| Trainable parameters | 2.4 million (projection only) | **16.2 million (projection + decoder)** |
|
| 401 |
+
| Best training loss | 2.9226 (stuck at plateau) | **0.8446** |
|
| 402 |
+
| Best validation loss | Not tracked | **0.7930** |
|
| 403 |
+
| **Best CIDEr score** | **0.0000** | **0.2863** |
|
| 404 |
+
| Generated text sample | `"iGiiiiiGiviqiGqiFl..."` | `"man in the bluess and white play with and a pizza"` |
|
| 405 |
+
|
| 406 |
+
### Epoch-by-Epoch Progression (Custom VLM Training After Fix)
|
| 407 |
+
|
| 408 |
+
This table shows how the Custom VLM improved over 15 epochs. **This is the key evidence that the fixes worked:**
|
| 409 |
+
|
| 410 |
+
| Epoch | Training Loss | Validation Loss | CIDEr Score | What Happened |
|
| 411 |
+
|---|---|---|---|---|
|
| 412 |
+
| 1 | 1.9234 | 1.1396 | 0.0577 | Immediately broke the 2.92 plateau |
|
| 413 |
+
| 2 | 1.2543 | 0.9671 | 0.1352 | CIDEr doubled — real words emerging |
|
| 414 |
+
| 3 | 1.1261 | 0.9253 | 0.1594 | Sentences forming |
|
| 415 |
+
| 6 | 0.9339 | 0.8627 | 0.2329 | Clear English captions |
|
| 416 |
+
| 8 | 0.8919 | 0.8530 | 0.2391 | Steady gains |
|
| 417 |
+
| 10 | 0.8715 | 0.8501 | 0.2598 | Continued improvement |
|
| 418 |
+
| **11** | **0.8573** | **0.8077** | **0.2863** | **🏆 Best CIDEr — saved as best checkpoint** |
|
| 419 |
+
| 12 | 0.8514 | 0.7973 | 0.2728 | CIDEr starts dipping (overfitting) |
|
| 420 |
+
| 15 | 0.8446 | 0.8055 | 0.2284 | Slight overfitting — CIDEr drops further |
|
| 421 |
+
|
| 422 |
+
**Key observations from this progression:**
|
| 423 |
+
|
| 424 |
+
1. **The loss plateau at 2.92 broke immediately** on epoch 1 once the decoder had properly loaded weights. This confirms the plateau was caused by the architecture mismatch, not a fundamental capacity limitation.
|
| 425 |
+
|
| 426 |
+
2. **CIDEr peaked at epoch 11 (0.2863) and then started declining** even though training loss continued to drop. This is classic **overfitting** — the model memorizes training examples instead of generalizing. This validates the decision to checkpoint based on CIDEr rather than loss.
|
| 427 |
+
|
| 428 |
+
3. **The best validation loss (0.7930 at epoch 14) and the best CIDEr (0.2863 at epoch 11) occurred at different epochs.** This proves that loss and caption quality are genuinely different things — lowest loss ≠ best captions.
|
| 429 |
+
|
| 430 |
+
---
|
| 431 |
+
|
| 432 |
+
## 6. Training Pipeline: Making It All Work
|
| 433 |
+
|
| 434 |
+
### 6.1 The Unified Training Script
|
| 435 |
+
|
| 436 |
+
All four architectures are trained through a single entry point: `python train.py --model {blip|vit_gpt2|git|custom}`. The script handles model selection, configuration loading, and device detection (MPS → CUDA → CPU) automatically.
|
| 437 |
+
|
| 438 |
+
### 6.2 Hyperparameters
|
| 439 |
+
|
| 440 |
+
| Parameter | BLIP | ViT-GPT2 | GIT | Custom VLM |
|
| 441 |
+
|---|---|---|---|---|
|
| 442 |
+
| Epochs | 3 | 3 | 3 | 15 |
|
| 443 |
+
| Learning Rate | 1e-5 | 2e-5 | 2e-5 | 1e-4 (projection) / 5e-5 (decoder) |
|
| 444 |
+
| Batch Size | 16 | 8 | 8 | 16 |
|
| 445 |
+
| Max Target Length | 32 tokens | 32 tokens | 32 tokens | 128 characters |
|
| 446 |
+
| Gradient Accumulation Steps | 4 | 4 | 4 | 4 |
|
| 447 |
+
| Warmup Ratio | 0.03 (3%) | 0.03 | 0.03 | 0.03 |
|
| 448 |
+
| Weight Decay | 0.01 | 0.01 | 0.01 | 0.01 |
|
| 449 |
+
| Optimizer | AdamW | AdamW | AdamW | AdamW |
|
| 450 |
+
| Learning Rate Schedule | Cosine with warmup | Cosine with warmup | Cosine with warmup | Cosine with warmup |
|
| 451 |
+
|
| 452 |
+
**Why these choices:**
|
| 453 |
+
|
| 454 |
+
- **BLIP gets a lower learning rate (1e-5)** because it is the largest and most sensitive to destabilization. The pre-trained HuggingFace models have already converged; aggressive updates would break their learned representations.
|
| 455 |
+
- **The Custom VLM gets 15 epochs** because the character-level decoder takes longer to converge — it needs to learn character-by-character spelling in addition to visual grounding. The other models produce subword tokens and need far fewer iterations.
|
| 456 |
+
- **Gradient accumulation of 4 with batch size 16** gives an effective batch size of 64. This smooths out gradient noise without requiring Apple Silicon to hold 64 images in memory at once.
|
| 457 |
+
|
| 458 |
+
### 6.3 Efficiency Optimizations
|
| 459 |
+
|
| 460 |
+
- **Gradient checkpointing** — Enabled for BLIP. Instead of storing all intermediate values in memory for the backward pass (backpropagation), the model recomputes them on the fly. This roughly halves memory usage at the cost of ~30% slower training. Essential for fitting the 224-million-parameter BLIP on consumer hardware.
|
| 461 |
+
|
| 462 |
+
- **MPS (Metal Performance Shaders) acceleration** — All models run on Apple Silicon's GPU. This required setting `num_workers=0` in the data loader (MPS does not support multiprocessing data loading) and capping images at 224×224 pixels.
|
| 463 |
+
|
| 464 |
+
- **Gradient norm clipping** — Gradients are clipped to a norm of 1.0 to prevent exploding gradients. This is particularly important during early training epochs when the Custom VLM's projection layer is learning from scratch and can produce very large gradient values.
|
| 465 |
+
|
| 466 |
+
- **Cosine learning rate scheduling with warmup** — The learning rate starts at zero, linearly warms up during the first 3% of training steps, then follows a cosine curve back down to near-zero. This gives the model time to find a good optimization direction before committing to steep gradients.
|
| 467 |
+
|
| 468 |
+
### 6.4 Checkpoint Management
|
| 469 |
+
|
| 470 |
+
Checkpoints are saved to two locations:
|
| 471 |
+
|
| 472 |
+
| Directory | What It Contains | When to Use |
|
| 473 |
+
|---|---|---|
|
| 474 |
+
| `outputs/{model}/best/` | Checkpoint with the **highest validation CIDEr** seen during training | ✅ Use for evaluation and deployment |
|
| 475 |
+
| `outputs/{model}/latest/` | Checkpoint from the most recent epoch | 🔧 Use for debugging or resuming training |
|
| 476 |
+
|
| 477 |
+
---
|
| 478 |
+
|
| 479 |
+
## 7. Experiments and Results
|
| 480 |
+
|
| 481 |
+
### 7.1 Experiment 1: Cross-Attention Masking — What Happens When We Hide Parts of the Image?
|
| 482 |
+
|
| 483 |
+
**Question:** How important is fine-grained spatial visual information for caption generation? Can we remove parts of the image and still get good captions?
|
| 484 |
+
|
| 485 |
+
I designed four masking modes that manipulate which image patches the decoder can "see" during inference (caption generation):
|
| 486 |
+
|
| 487 |
+
**Mode 1 — Baseline (Full Attention)**
|
| 488 |
+
All 197 patches (1 class token + 196 spatial patches from the 14×14 grid) are visible. This is the upper-bound reference — the model sees the entire image.
|
| 489 |
+
|
| 490 |
+
**Mode 2 — Random Patch Dropout (50%)**
|
| 491 |
+
Randomly hide 50% of the 196 spatial patches; the class token always stays visible. Does the model still generate good captions with half the image hidden?
|
| 492 |
+
|
| 493 |
+
**Mode 3 — Center-Focus (Keep Only Inner 8×8 Grid)**
|
| 494 |
+
Only keep the inner 64 patches of the 14×14 spatial grid, dropping the entire outer ring (the background and periphery). Does removing the edges and background matter?
|
| 495 |
+
|
| 496 |
+
**Mode 4 — Squint (Compress Everything to One Token)**
|
| 497 |
+
Average all 196 spatial patches into a single global summary token. The mask becomes just 2 tokens: the class token and this one average. Can the model work with an extremely compressed representation?
|
| 498 |
+
|
| 499 |
+
**Results (BLIP, base pre-trained weights, 25 evaluation batches):**
|
| 500 |
+
|
| 501 |
+
| Mode | CIDEr Score | Change from Baseline | Interpretation |
|
| 502 |
+
|---|---|---|---|
|
| 503 |
+
| ✅ Baseline | **0.5371** | — | Full information reference |
|
| 504 |
+
| 🎲 Random Dropout (50%) | **0.5371** | +0.0000 (zero change!) | **Massive spatial redundancy — half the patches are disposable** |
|
| 505 |
+
| 🎯 Center-Focus (8×8) | **0.5371** | +0.0000 (zero change!) | **Background and edges contribute nothing** |
|
| 506 |
+
| 👀 Squint (Global Pool) | **0.0008** | −0.5363 (99.8% drop) | **Catastrophic failure — local details are essential** |
|
| 507 |
+
|
| 508 |
+
**What do these results mean?**
|
| 509 |
+
|
| 510 |
+
These results reveal something fascinating about how vision models process images:
|
| 511 |
+
|
| 512 |
+
- **Random dropout and center-focus cause zero degradation.** This means that for standard captioning, roughly **half of all spatial patches are entirely redundant**. The model can generate equally good captions with only 98 patches as with all 196. Background patches (the outer ring) also contribute nothing measurable.
|
| 513 |
+
|
| 514 |
+
- **But squinting destroys performance completely.** When you compress all 196 patches into a single average vector, CIDEr drops to essentially zero. This proves that while many individual patches are redundant, their collective **spatial arrangement** carries critical information. A single global vector cannot capture object locations, spatial relationships, and scene layout.
|
| 515 |
+
|
| 516 |
+
> **The takeaway:** BLIP's cross-attention is extremely robust to significant patch dropout, but it fundamentally requires spatially-distributed features. The spatial structure of the image matters more than the quantity of patches.
|
| 517 |
+
|
| 518 |
+
### 7.2 Experiment 2: Decoding Parameter Sweep — Finding the Best Caption Generation Settings
|
| 519 |
+
|
| 520 |
+
**Question:** How do beam search settings affect caption quality?
|
| 521 |
+
|
| 522 |
+
**What is beam search?** When a model generates text, it does not just pick the most probable next word at each step (that is called "greedy search" and often produces mediocre results). Instead, beam search maintains multiple candidate sentences simultaneously and picks the one with the best overall probability. Beam width controls how many candidates to track — more beams means more exploration but slower generation.
|
| 523 |
+
|
| 524 |
+
I swept across three decoding parameters for BLIP:
|
| 525 |
+
- **Beam sizes:** 3, 5, 10 (how many candidate sentences to track)
|
| 526 |
+
- **Length penalties:** 0.8, 1.0, 1.2 (penalty < 1.0 encourages longer captions, > 1.0 encourages shorter)
|
| 527 |
+
- **Max new tokens:** 20, 50 (maximum caption length allowed)
|
| 528 |
+
|
| 529 |
+
This produced **18 configurations** (3 × 3 × 2). Here are the results ranked by CIDEr score:
|
| 530 |
+
|
| 531 |
+
| Beams | Length Penalty | Max Tokens | CIDEr Score |
|
| 532 |
+
|---|---|---|---|
|
| 533 |
+
| 10 | 1.2 | 50 | **0.6199** ← 🏆 best |
|
| 534 |
+
| 10 | 1.0 | 20 | 0.5904 |
|
| 535 |
+
| 5 | 1.0 | 20 | 0.5896 |
|
| 536 |
+
| 10 | 1.2 | 20 | 0.5785 |
|
| 537 |
+
| 10 | 0.8 | 50 | 0.5722 |
|
| 538 |
+
| 3 | 1.2 | 20 | 0.5653 |
|
| 539 |
+
| 5 | 1.0 | 50 | 0.5598 |
|
| 540 |
+
| 5 | 1.2 | 20 | 0.5533 |
|
| 541 |
+
| 10 | 1.0 | 50 | 0.5457 |
|
| 542 |
+
| 3 | 1.2 | 50 | 0.5456 |
|
| 543 |
+
| 3 | 1.0 | 20 | 0.5451 |
|
| 544 |
+
| 10 | 0.8 | 20 | 0.5321 |
|
| 545 |
+
| 3 | 1.0 | 50 | 0.5262 |
|
| 546 |
+
| 5 | 1.2 | 50 | 0.5106 |
|
| 547 |
+
| 5 | 0.8 | 20 | 0.5046 |
|
| 548 |
+
| 3 | 0.8 | 50 | 0.5031 |
|
| 549 |
+
| 5 | 0.8 | 50 | 0.4914 |
|
| 550 |
+
| 3 | 0.8 | 20 | 0.4783 |
|
| 551 |
+
|
| 552 |
+
**Key findings:**
|
| 553 |
+
|
| 554 |
+
- **Beam size is the most impactful parameter.** Going from 3 beams to 10 beams with the best other settings improves CIDEr from ~0.55 to ~0.62 — an approximate **13% improvement**. More candidate sentences means better final selection.
|
| 555 |
+
- **Slight preference for shorter captions helps (length penalty 1.2).** BLIP tends to "ramble" with longer generation budgets, and concise captions match human references better.
|
| 556 |
+
- **The best combination is: beam_size=10, length_penalty=1.2, max_tokens=50** — yielding a CIDEr of **0.6199**.
|
| 557 |
+
|
| 558 |
+
### 7.3 Experiment 3: Caption Quality Filtering — Does Training Data Quality Matter?
|
| 559 |
+
|
| 560 |
+
**Question:** Does filtering caption quality before training improve model performance?
|
| 561 |
+
|
| 562 |
+
I evaluated BLIP under four caption selection strategies (what kind of captions we feed the model during training):
|
| 563 |
+
|
| 564 |
+
| Strategy | CIDEr Score | Change from Raw | Interpretation |
|
| 565 |
+
|---|---|---|---|
|
| 566 |
+
| raw (no filtering) | **0.6359** | — | **Best for this clean dataset** |
|
| 567 |
+
| short (≤ 9 words) | 0.6016 | −0.0342 | Too brief for good word overlap |
|
| 568 |
+
| filtered (5–25 words) | 0.5877 | −0.0481 | Quality filter |
|
| 569 |
+
| long (≥ 12 words) | 0.5389 | −0.0970 | Too verbose for base model |
|
| 570 |
+
|
| 571 |
+
**Why did raw perform best?** The COCO dataset is already relatively clean (mean word count 10.4, only 0.2% of captions over 25 words), so filtering actually removes useful variety. However, the **filtered strategy is still recommended as a general default** because it protects against noisy outliers in less curated datasets and ensures reproducible, consistent training behavior.
|
| 572 |
+
|
| 573 |
+
---
|
| 574 |
+
|
| 575 |
+
## 8. The Streamlit Application
|
| 576 |
+
|
| 577 |
+
The interactive demo is implemented in `app.py` and provides a complete interface for exploring and comparing all four architectures.
|
| 578 |
+
|
| 579 |
+
### 8.1 Features
|
| 580 |
+
|
| 581 |
+
| Feature | What It Does |
|
| 582 |
+
|---|---|
|
| 583 |
+
| **Caption Tab** | Upload an image, select a model and generation mode, generate a caption |
|
| 584 |
+
| **Compare All Models Tab** | Run all 4 architectures side-by-side on the same image with a summary table |
|
| 585 |
+
| **Experiment Results Tab** | View pre-computed results from all three experiments |
|
| 586 |
+
| **Weight Source Selector** | Switch between base (pre-trained), fine-tuned (best CIDEr), and fine-tuned (latest) weights |
|
| 587 |
+
| **Advanced Controls** | Adjust beam width, temperature, length penalty, top-k, and top-p |
|
| 588 |
+
| **Toxicity Filter** | Every caption is checked through `unitary/toxic-bert` before display |
|
| 589 |
+
|
| 590 |
+
### 8.2 Architecture Info Cards
|
| 591 |
+
|
| 592 |
+
Each model gets a descriptive card in the sidebar explaining its cross-attention approach in plain language:
|
| 593 |
+
|
| 594 |
+
- **BLIP:** "Gated cross-attention is injected between self-attention and feed-forward layers in the decoder, allowing fine-grained visual feature querying at each decoding step."
|
| 595 |
+
- **ViT-GPT2:** "Every GPT-2 text token attends to all 197 ViT patch embeddings via full cross-attention at every decoder layer."
|
| 596 |
+
- **GIT:** "Image patches are concatenated to the front of the token sequence; causal self-attention handles everything in one flat joint sequence."
|
| 597 |
+
- **Custom VLM:** "Fuses a frozen ViT with a Shakespeare character-level decoder via a trainable projection."
|
| 598 |
+
|
| 599 |
+
### 8.3 Safety: Toxicity Filtering
|
| 600 |
+
|
| 601 |
+
Because captioning models can occasionally generate offensive descriptions (particularly on ambiguous or culturally sensitive images), every generated caption passes through the `detoxify` library's `unitary/toxic-bert` model before being displayed. If the toxicity score exceeds a threshold, the caption is redacted and the user is warned.
|
| 602 |
+
|
| 603 |
+
---
|
| 604 |
+
|
| 605 |
+
## 9. Key Insights and Analytical Conclusions
|
| 606 |
+
|
| 607 |
+
### 9.1 Cross-Attention Is Helpful but Not Mandatory
|
| 608 |
+
|
| 609 |
+
GIT achieves strong captioning performance using only prefix self-attention — **no dedicated cross-attention blocks at all**. This proves that cross-attention, while helpful for selective visual querying, is not strictly mandatory for multimodal fusion. The prefix concatenation approach works because self-attention is a universal mechanism: as long as visual and text tokens share the same sequence, the model learns to route information between modalities.
|
| 610 |
+
|
| 611 |
+
### 9.2 Gated Attention Gives the Best Trade-Off
|
| 612 |
+
|
| 613 |
+
**BLIP's gated cross-attention achieves the highest CIDEr scores** because the gate selectively filters visual information. When generating syntax words ("the," "a"), the gate closes and the model relies on its language model. When generating content words ("dog," "bicycle"), the gate opens and visual features flow through. This prevents attention collapse — a failure mode where too much visual information disrupts language coherence.
|
| 614 |
+
|
| 615 |
+
### 9.3 Images Contain Massive Spatial Redundancy
|
| 616 |
+
|
| 617 |
+
The masking experiment proves that **50% of image patches can be removed with zero quality loss**, and cropping to the center removes the entire background with no effect. But compressing to a single global vector destroys performance. This means: **spatial structure matters more than absolute patch count.**
|
| 618 |
+
|
| 619 |
+
### 9.4 Loss and Quality Are Different Things
|
| 620 |
+
|
| 621 |
+
The Custom VLM training showed that **the best training loss and the best CIDEr occurred at different epochs** (epoch 14 vs. epoch 11). A model that predicts the next token well (low loss) is not necessarily a model that produces captions humans would agree with (high CIDEr). **Always evaluate with task-specific metrics, not just loss.**
|
| 622 |
+
|
| 623 |
+
### 9.5 Silent Failures Are the Worst Kind of Bug
|
| 624 |
+
|
| 625 |
+
The most time-consuming problem in this project was a weight-loading failure that produced **no error message, no warning, and no indication** that 98.5% of the model failed to load. **In production machine learning code, always verify how many tensors actually loaded when using `strict=False`.**
|
| 626 |
+
|
| 627 |
+
---
|
| 628 |
+
|
| 629 |
+
## 10. Future Improvements
|
| 630 |
+
|
| 631 |
+
The Custom VLM currently achieves a best CIDEr of **0.2863**. Here is a roadmap of improvements ordered by expected impact:
|
| 632 |
+
|
| 633 |
+
### High Impact (Could Improve CIDEr by +0.15 to +0.40 Each)
|
| 634 |
+
|
| 635 |
+
| Improvement | What It Changes | Expected CIDEr Gain |
|
| 636 |
+
|---|---|---|
|
| 637 |
+
| **Switch from characters to subword tokens** | "standing" becomes 1 token instead of 8 characters | +0.15 to +0.30 |
|
| 638 |
+
| **Replace Shakespeare decoder with GPT-2 Small** | GPT-2 already knows modern English; Shakespeare decoder had to learn both English and captioning | +0.20 to +0.40 |
|
| 639 |
+
| **Increase training data (15K → 80K)** | Use the full COCO training set instead of 18% | +0.05 to +0.10 |
|
| 640 |
+
|
| 641 |
+
### Medium Impact (Could Improve CIDEr by +0.05 to +0.15 Each)
|
| 642 |
+
|
| 643 |
+
| Improvement | What It Changes |
|
| 644 |
+
|---|---|
|
| 645 |
+
| **Label smoothing** (0.1) | Prevents overconfident character predictions |
|
| 646 |
+
| **Multi-reference CIDEr** (use all 5 human captions) | More accurate quality measurement |
|
| 647 |
+
| **Proper cross-attention layers** in the decoder | Dedicated vision-text interaction instead of prefix concatenation |
|
| 648 |
+
| **Stronger vision encoder** (CLIP ViT-Large) | CLIP features are inherently aligned with text |
|
| 649 |
+
|
| 650 |
+
---
|
| 651 |
+
|
| 652 |
+
## 11. Reproducibility and Commands
|
| 653 |
+
|
| 654 |
+
### Environment Setup
|
| 655 |
+
|
| 656 |
+
```bash
|
| 657 |
+
python -m venv venv
|
| 658 |
+
source venv/bin/activate
|
| 659 |
+
pip install -r requirements.txt
|
| 660 |
+
|
| 661 |
+
# Verify acceleration is available (Apple Silicon)
|
| 662 |
+
python -c "import torch; print(torch.backends.mps.is_available())"
|
| 663 |
+
```
|
| 664 |
+
|
| 665 |
+
### Training
|
| 666 |
+
|
| 667 |
+
```bash
|
| 668 |
+
python train.py --model blip # ~1.5 hours on Apple Silicon
|
| 669 |
+
python train.py --model vit_gpt2 # ~1 hour
|
| 670 |
+
python train.py --model git # ~20 minutes
|
| 671 |
+
python train.py --model custom # ~3 hours (15 epochs)
|
| 672 |
+
```
|
| 673 |
+
|
| 674 |
+
### Evaluation
|
| 675 |
+
|
| 676 |
+
```bash
|
| 677 |
+
# Evaluate one model
|
| 678 |
+
python eval.py --model blip --weights best
|
| 679 |
+
|
| 680 |
+
# Compare all models
|
| 681 |
+
python eval.py --model all --weights best
|
| 682 |
+
|
| 683 |
+
# Run cross-attention masking experiment
|
| 684 |
+
python eval.py --model blip --ablation --weights best
|
| 685 |
+
|
| 686 |
+
# Run decoding parameter sweep
|
| 687 |
+
python eval.py --model blip --sweep --weights best
|
| 688 |
+
|
| 689 |
+
# Custom decoding settings
|
| 690 |
+
python eval.py --model blip --weights best --num_beams 10 --max_new_tokens 50 --length_penalty 1.2
|
| 691 |
+
```
|
| 692 |
+
|
| 693 |
+
### Streamlit Demo
|
| 694 |
+
|
| 695 |
+
```bash
|
| 696 |
+
streamlit run app.py
|
| 697 |
+
```
|
| 698 |
+
|
| 699 |
+
---
|
| 700 |
+
|
| 701 |
+
## 12. Project Structure
|
| 702 |
+
|
| 703 |
+
```
|
| 704 |
+
project_02/
|
| 705 |
+
├── app.py # Streamlit demo (3 tabs: Caption, Compare, Results)
|
| 706 |
+
├── config.py # Backward-compatible config wrapper
|
| 707 |
+
├── data_prep.py # Dataset loading + caption filtering strategies
|
| 708 |
+
├── eval.py # Unified CIDEr evaluator + experiment runner
|
| 709 |
+
├── train.py # Unified training loop for all 4 models
|
| 710 |
+
├── requirements.txt # Python dependencies
|
| 711 |
+
├── input.txt # Shakespeare corpus (character vocabulary source)
|
| 712 |
+
├── shakespeare_transformer.pt # Pre-trained Shakespeare decoder weights
|
| 713 |
+
│
|
| 714 |
+
├── configs/
|
| 715 |
+
│ ├── __init__.py # get_config() factory function
|
| 716 |
+
│ ├── base_config.py # Shared hyperparameters for all models
|
| 717 |
+
│ ├── blip_config.py # BLIP-specific settings
|
| 718 |
+
│ ├── vit_gpt2_config.py # ViT-GPT2-specific settings
|
| 719 |
+
│ ├── git_config.py # GIT-specific settings
|
| 720 |
+
│ └── custom_vlm_config.py # Custom VLM-specific settings
|
| 721 |
+
│
|
| 722 |
+
├── models/
|
| 723 |
+
│ ├── blip_tuner.py # BLIP: gated cross-attention
|
| 724 |
+
│ ├── vit_gpt2_tuner.py # ViT-GPT2: full cross-attention
|
| 725 |
+
│ ├── git_tuner.py # GIT: zero cross-attention
|
| 726 |
+
│ └── custom_vlm.py # Custom VLM: visual prefix-tuning
|
| 727 |
+
│
|
| 728 |
+
├── experiments/
|
| 729 |
+
│ ├── ablation_study.py # 4-mode attention masking experiment
|
| 730 |
+
│ ├── parameter_sweep.py # Beam/penalty/token sweep
|
| 731 |
+
│ ├── cross_attention_patterns.py # Architecture comparison
|
| 732 |
+
│ ├── data_prep_analysis.py # Caption filtering analysis
|
| 733 |
+
│ ├── results_cross_attention_masking_impact_on_caption_quality.md # Masking experiment results
|
| 734 |
+
│ ├── results_beam_search_and_decoding_settings_comparison.md # Sweep results
|
| 735 |
+
│ └── results_caption_filtering_strategy_comparison.md # Filtering results
|
| 736 |
+
│
|
| 737 |
+
├── outputs/
|
| 738 |
+
│ ├── blip/{best,latest}/ # BLIP checkpoints
|
| 739 |
+
│ └── custom_vlm/{best,latest}/ # Custom VLM checkpoints
|
| 740 |
+
│
|
| 741 |
+
└── README.md # Project overview and setup guide
|
| 742 |
+
```
|
| 743 |
+
|
| 744 |
+
---
|
| 745 |
+
|
| 746 |
+
**Technologies Used:** Python 3.9+, PyTorch, HuggingFace Transformers, HuggingFace Datasets, Streamlit, pycocoevalcap (CIDEr evaluation), detoxify (toxicity filtering), Pillow, NumPy, tqdm, accelerate.
|
| 747 |
+
|
| 748 |
+
**Hardware:** Apple Silicon Mac with MPS (Metal Performance Shaders) acceleration.
|
eval.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
eval.py
|
| 3 |
+
=======
|
| 4 |
+
Unified Evaluator — CIDEr across all four VLM architectures.
|
| 5 |
+
|
| 6 |
+
This module:
|
| 7 |
+
1. Evaluates each model's baseline CIDEr on the COCO validation set
|
| 8 |
+
2. Delegates ablation studies to experiments/ablation_study.py
|
| 9 |
+
3. Provides a unified cross-model comparison table
|
| 10 |
+
|
| 11 |
+
Weight Selection (--weights flag):
|
| 12 |
+
base → Use pretrained HuggingFace weights (no fine-tuning)
|
| 13 |
+
finetuned → Load from outputs/{model}/latest/
|
| 14 |
+
best → Load from outputs/{model}/best/
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
python eval.py # BLIP base weights
|
| 18 |
+
python eval.py --model blip --weights best # BLIP best fine-tuned
|
| 19 |
+
python eval.py --model all # All 4 models
|
| 20 |
+
python eval.py --model all --weights best # All 4 models, best weights
|
| 21 |
+
python eval.py --ablation # BLIP 4-mode ablation
|
| 22 |
+
python eval.py --sweep # Decoding parameter sweep
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import argparse
|
| 27 |
+
import torch
|
| 28 |
+
from typing import Optional
|
| 29 |
+
from tqdm.auto import tqdm
|
| 30 |
+
from pycocoevalcap.cider.cider import Cider
|
| 31 |
+
|
| 32 |
+
from config import CFG
|
| 33 |
+
from data_prep import get_dataloaders, get_dataloaders_for_model
|
| 34 |
+
from models.blip_tuner import get_blip_model, load_ckpt, generate_with_mask
|
| 35 |
+
from experiments.ablation_study import run_ablation_study
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 39 |
+
# Device Helper
|
| 40 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
def get_device():
|
| 43 |
+
if torch.backends.mps.is_available():
|
| 44 |
+
return torch.device("mps")
|
| 45 |
+
elif torch.cuda.is_available():
|
| 46 |
+
return torch.device("cuda")
|
| 47 |
+
return torch.device("cpu")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 51 |
+
# Weight Loading Helpers
|
| 52 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 53 |
+
|
| 54 |
+
def get_weights_dir(cfg, model_name: str, weights: str) -> Optional[str]:
|
| 55 |
+
"""
|
| 56 |
+
Return the checkpoint directory for the given model and weight selection.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
cfg : CFG instance
|
| 60 |
+
model_name : 'blip', 'vit_gpt2', 'git', 'custom'
|
| 61 |
+
weights : 'base', 'finetuned', 'best'
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Absolute path to checkpoint dir, or None for base weights.
|
| 65 |
+
"""
|
| 66 |
+
if weights == "base":
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
subdir = "latest" if weights == "finetuned" else "best"
|
| 70 |
+
path = os.path.join(cfg.output_root, model_name, subdir)
|
| 71 |
+
|
| 72 |
+
if os.path.isdir(path) and os.listdir(path):
|
| 73 |
+
return path
|
| 74 |
+
|
| 75 |
+
print(f"⚠️ No {subdir} checkpoint found at {path}. Falling back to base weights.")
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def print_weights_banner(model_name: str, weights: str, ckpt_dir: Optional[str]):
|
| 80 |
+
"""Print a clear banner showing which weights are being used."""
|
| 81 |
+
print("=" * 60)
|
| 82 |
+
print(f" Model: {model_name}")
|
| 83 |
+
if ckpt_dir:
|
| 84 |
+
print(f" Weights: {weights} → {ckpt_dir}")
|
| 85 |
+
else:
|
| 86 |
+
print(f" Weights: base (pretrained, no fine-tuning)")
|
| 87 |
+
print("=" * 60)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 91 |
+
# BLIP Baseline CIDEr Evaluation
|
| 92 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 93 |
+
|
| 94 |
+
def evaluate_blip(model, processor, dataloader, device,
|
| 95 |
+
num_beams=4, max_new_tokens=32, length_penalty=1.0,
|
| 96 |
+
eval_batches=25):
|
| 97 |
+
"""Evaluate BLIP CIDEr score (full attention — no ablation masking)."""
|
| 98 |
+
model.eval()
|
| 99 |
+
gts, res = {}, {}
|
| 100 |
+
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
for i, batch in enumerate(tqdm(dataloader, desc="Eval [BLIP]")):
|
| 103 |
+
if i >= eval_batches:
|
| 104 |
+
break
|
| 105 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 106 |
+
B = pixel_values.shape[0]
|
| 107 |
+
mask = torch.ones(B, 197, dtype=torch.long, device=device)
|
| 108 |
+
|
| 109 |
+
decoded = generate_with_mask(
|
| 110 |
+
model, processor, device=device,
|
| 111 |
+
pixel_values=pixel_values,
|
| 112 |
+
encoder_attention_mask=mask,
|
| 113 |
+
max_new_tokens=max_new_tokens,
|
| 114 |
+
num_beams=num_beams,
|
| 115 |
+
)
|
| 116 |
+
preds = decoded # generate_with_mask already returns decoded strings
|
| 117 |
+
labels = batch["labels"].clone()
|
| 118 |
+
gt_texts = processor.batch_decode(labels, skip_special_tokens=True)
|
| 119 |
+
|
| 120 |
+
for j, (p, g) in enumerate(zip(preds, gt_texts)):
|
| 121 |
+
k = str(i * len(preds) + j)
|
| 122 |
+
res[k] = [p]
|
| 123 |
+
gts[k] = [g]
|
| 124 |
+
|
| 125 |
+
if not gts:
|
| 126 |
+
return 0.0
|
| 127 |
+
scorer = Cider()
|
| 128 |
+
score, _ = scorer.compute_score(gts, res)
|
| 129 |
+
print(f" ✅ CIDEr [BLIP]: {score:.4f}")
|
| 130 |
+
return score
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 134 |
+
# ViT-GPT2 CIDEr Evaluation
|
| 135 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 136 |
+
|
| 137 |
+
def evaluate_vit_gpt2(model, tokenizer, dataloader, device,
|
| 138 |
+
num_beams=4, max_new_tokens=32, length_penalty=1.0,
|
| 139 |
+
eval_batches=25):
|
| 140 |
+
"""Evaluate ViT-GPT2 CIDEr score."""
|
| 141 |
+
model.eval()
|
| 142 |
+
gts, res = {}, {}
|
| 143 |
+
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
for i, batch in enumerate(tqdm(dataloader, desc="Eval [ViT-GPT2]")):
|
| 146 |
+
if i >= eval_batches:
|
| 147 |
+
break
|
| 148 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 149 |
+
out = model.generate(
|
| 150 |
+
pixel_values=pixel_values,
|
| 151 |
+
num_beams=num_beams,
|
| 152 |
+
max_new_tokens=max_new_tokens,
|
| 153 |
+
length_penalty=length_penalty,
|
| 154 |
+
)
|
| 155 |
+
preds = [tokenizer.decode(ids, skip_special_tokens=True) for ids in out]
|
| 156 |
+
labels = batch["labels"].clone()
|
| 157 |
+
labels[labels == -100] = tokenizer.pad_token_id
|
| 158 |
+
gt_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
| 159 |
+
|
| 160 |
+
for j, (p, g) in enumerate(zip(preds, gt_texts)):
|
| 161 |
+
k = str(i * len(preds) + j)
|
| 162 |
+
res[k] = [p]
|
| 163 |
+
gts[k] = [g]
|
| 164 |
+
|
| 165 |
+
if not gts:
|
| 166 |
+
return 0.0
|
| 167 |
+
scorer = Cider()
|
| 168 |
+
score, _ = scorer.compute_score(gts, res)
|
| 169 |
+
print(f" ✅ CIDEr [ViT-GPT2]: {score:.4f}")
|
| 170 |
+
return score
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 174 |
+
# GIT CIDEr Evaluation
|
| 175 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 176 |
+
|
| 177 |
+
def evaluate_git(model, processor, dataloader, device,
|
| 178 |
+
num_beams=4, max_new_tokens=32, length_penalty=1.0,
|
| 179 |
+
eval_batches=25):
|
| 180 |
+
"""Evaluate GIT CIDEr score."""
|
| 181 |
+
model.eval()
|
| 182 |
+
gts, res = {}, {}
|
| 183 |
+
|
| 184 |
+
with torch.no_grad():
|
| 185 |
+
for i, batch in enumerate(tqdm(dataloader, desc="Eval [GIT]")):
|
| 186 |
+
if i >= eval_batches:
|
| 187 |
+
break
|
| 188 |
+
inputs = {k: v.to(device) for k, v in batch.items()
|
| 189 |
+
if k in ("pixel_values", "input_ids", "attention_mask")}
|
| 190 |
+
out = model.generate(
|
| 191 |
+
**inputs,
|
| 192 |
+
num_beams=num_beams,
|
| 193 |
+
max_new_tokens=max_new_tokens,
|
| 194 |
+
length_penalty=length_penalty,
|
| 195 |
+
)
|
| 196 |
+
preds = processor.batch_decode(out, skip_special_tokens=True)
|
| 197 |
+
labels = batch["labels"].clone()
|
| 198 |
+
labels[labels == -100] = processor.tokenizer.pad_token_id
|
| 199 |
+
gt_texts = processor.batch_decode(labels, skip_special_tokens=True)
|
| 200 |
+
|
| 201 |
+
for j, (p, g) in enumerate(zip(preds, gt_texts)):
|
| 202 |
+
k = str(i * len(preds) + j)
|
| 203 |
+
res[k] = [p]
|
| 204 |
+
gts[k] = [g]
|
| 205 |
+
|
| 206 |
+
if not gts:
|
| 207 |
+
return 0.0
|
| 208 |
+
scorer = Cider()
|
| 209 |
+
score, _ = scorer.compute_score(gts, res)
|
| 210 |
+
print(f" ✅ CIDEr [GIT]: {score:.4f}")
|
| 211 |
+
return score
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 215 |
+
# Custom VLM CIDEr Evaluation
|
| 216 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 217 |
+
|
| 218 |
+
def evaluate_custom_vlm_cider(model, val_loader, device,
|
| 219 |
+
char_to_idx, idx_to_char,
|
| 220 |
+
max_new_tokens=80, num_beams=1,
|
| 221 |
+
length_penalty=1.0,
|
| 222 |
+
eval_batches=20):
|
| 223 |
+
"""Evaluate CIDEr score for the CustomVLM using autoregressive generation."""
|
| 224 |
+
model.eval()
|
| 225 |
+
gts, res = {}, {}
|
| 226 |
+
|
| 227 |
+
print("\nEvaluating Custom VLM (Visual Prefix-Tuning)...")
|
| 228 |
+
|
| 229 |
+
with torch.no_grad():
|
| 230 |
+
for i, batch in enumerate(tqdm(val_loader, desc="Eval [CustomVLM]")):
|
| 231 |
+
if i >= eval_batches:
|
| 232 |
+
break
|
| 233 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 234 |
+
B = pixel_values.shape[0]
|
| 235 |
+
|
| 236 |
+
for b in range(B):
|
| 237 |
+
pv_single = pixel_values[b:b+1]
|
| 238 |
+
|
| 239 |
+
if num_beams > 1:
|
| 240 |
+
pred = model.generate_beam(
|
| 241 |
+
pv_single, char_to_idx, idx_to_char,
|
| 242 |
+
max_new_tokens=max_new_tokens,
|
| 243 |
+
num_beams=num_beams,
|
| 244 |
+
length_penalty=length_penalty,
|
| 245 |
+
)
|
| 246 |
+
else:
|
| 247 |
+
pred = model.generate(
|
| 248 |
+
pv_single, char_to_idx, idx_to_char,
|
| 249 |
+
max_new_tokens=max_new_tokens,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
tgt_ids = batch["text_targets"][b].tolist()
|
| 253 |
+
gt_text = "".join(idx_to_char.get(idx, "") for idx in tgt_ids if idx != 0)
|
| 254 |
+
|
| 255 |
+
idx_key = str(i * B + b)
|
| 256 |
+
res[idx_key] = [pred.strip()]
|
| 257 |
+
gts[idx_key] = [gt_text.strip()]
|
| 258 |
+
|
| 259 |
+
if not gts:
|
| 260 |
+
return 0.0
|
| 261 |
+
|
| 262 |
+
scorer = Cider()
|
| 263 |
+
score, _ = scorer.compute_score(gts, res)
|
| 264 |
+
print(f" ✅ CIDEr [CustomVLM]: {score:.4f}")
|
| 265 |
+
return score
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 269 |
+
# Custom VLM Loader (with weight selection)
|
| 270 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 271 |
+
|
| 272 |
+
def load_custom_vlm_for_eval(cfg, device, weights="base"):
|
| 273 |
+
"""
|
| 274 |
+
Load CustomVLM with the specified weight selection.
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
weights: 'base' (Shakespeare only), 'finetuned' (latest ckpt), 'best' (best ckpt)
|
| 278 |
+
"""
|
| 279 |
+
from models.custom_vlm import CustomVLM, build_char_vocab
|
| 280 |
+
from data_prep import get_custom_vlm_dataloader
|
| 281 |
+
|
| 282 |
+
with open(cfg.shakespeare_file, "r") as f:
|
| 283 |
+
text = f.read()
|
| 284 |
+
_, c2i, i2c, vs = build_char_vocab(text)
|
| 285 |
+
|
| 286 |
+
model = CustomVLM(
|
| 287 |
+
vocab_size=vs,
|
| 288 |
+
text_embed_dim=cfg.text_embed_dim,
|
| 289 |
+
n_heads=cfg.n_heads,
|
| 290 |
+
n_layers=cfg.n_layers,
|
| 291 |
+
block_size=cfg.block_size,
|
| 292 |
+
dropout=cfg.dropout,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Always load Shakespeare weights first
|
| 296 |
+
if os.path.exists(cfg.shakespeare_weights_path):
|
| 297 |
+
model.load_shakespeare_weights(cfg.shakespeare_weights_path)
|
| 298 |
+
|
| 299 |
+
# Then optionally load fine-tuned weights on top
|
| 300 |
+
ckpt_dir = get_weights_dir(cfg, "custom_vlm", weights)
|
| 301 |
+
if ckpt_dir:
|
| 302 |
+
ckpt_path = os.path.join(ckpt_dir, "custom_vlm.pt")
|
| 303 |
+
if os.path.exists(ckpt_path):
|
| 304 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
| 305 |
+
# Filter shape mismatches gracefully
|
| 306 |
+
own_state = model.state_dict()
|
| 307 |
+
filtered = {k: v for k, v in state["model_state"].items()
|
| 308 |
+
if k in own_state and own_state[k].shape == v.shape}
|
| 309 |
+
model.load_state_dict(filtered, strict=False)
|
| 310 |
+
print(f" ✅ Loaded fine-tuned weights from {ckpt_path}")
|
| 311 |
+
|
| 312 |
+
print_weights_banner("Custom VLM", weights, ckpt_dir)
|
| 313 |
+
model.to(device).eval()
|
| 314 |
+
|
| 315 |
+
_, val_loader = get_custom_vlm_dataloader(cfg, c2i)
|
| 316 |
+
return model, c2i, i2c, val_loader
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 320 |
+
# All-Model Comparison Table
|
| 321 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 322 |
+
|
| 323 |
+
def evaluate_all_models(cfg, device, weights="base",
|
| 324 |
+
num_beams=4, max_new_tokens=32,
|
| 325 |
+
length_penalty=1.0, eval_batches=25):
|
| 326 |
+
"""Run CIDEr evaluation for all four models and print a comparison table."""
|
| 327 |
+
results = {}
|
| 328 |
+
|
| 329 |
+
# ── BLIP ────────────────────────────────────────────────────────────────
|
| 330 |
+
print("\n[1/4] Evaluating BLIP...")
|
| 331 |
+
blip_cfg = CFG.load_for_model("blip")
|
| 332 |
+
model_b, proc_b = get_blip_model(blip_cfg, device)
|
| 333 |
+
ckpt = get_weights_dir(blip_cfg, "blip", weights)
|
| 334 |
+
if ckpt:
|
| 335 |
+
load_ckpt(model_b, None, None, ckpt)
|
| 336 |
+
print_weights_banner("BLIP", weights, ckpt)
|
| 337 |
+
_, val_b = get_dataloaders(blip_cfg, proc_b)
|
| 338 |
+
results["BLIP"] = evaluate_blip(
|
| 339 |
+
model_b, proc_b, val_b, device,
|
| 340 |
+
num_beams=num_beams, max_new_tokens=max_new_tokens,
|
| 341 |
+
length_penalty=length_penalty, eval_batches=eval_batches,
|
| 342 |
+
)
|
| 343 |
+
del model_b, proc_b
|
| 344 |
+
|
| 345 |
+
# ── ViT-GPT2 ────────────────────────────────────────────────────────────
|
| 346 |
+
print("\n[2/4] Evaluating ViT-GPT2...")
|
| 347 |
+
from models.vit_gpt2_tuner import get_vit_gpt2_model
|
| 348 |
+
vg2_cfg = CFG.load_for_model("vit_gpt2")
|
| 349 |
+
model_v, proc_v, tok_v = get_vit_gpt2_model(vg2_cfg, device)
|
| 350 |
+
ckpt = get_weights_dir(vg2_cfg, "vit_gpt2", weights)
|
| 351 |
+
if ckpt:
|
| 352 |
+
from transformers import VisionEncoderDecoderModel
|
| 353 |
+
finetuned = VisionEncoderDecoderModel.from_pretrained(ckpt)
|
| 354 |
+
model_v.load_state_dict(finetuned.state_dict())
|
| 355 |
+
model_v.to(device)
|
| 356 |
+
print_weights_banner("ViT-GPT2", weights, ckpt)
|
| 357 |
+
_, val_v = get_dataloaders_for_model(vg2_cfg, "vit_gpt2", proc_v, tok_v)
|
| 358 |
+
results["ViT-GPT2"] = evaluate_vit_gpt2(
|
| 359 |
+
model_v, tok_v, val_v, device,
|
| 360 |
+
num_beams=num_beams, max_new_tokens=max_new_tokens,
|
| 361 |
+
length_penalty=length_penalty, eval_batches=eval_batches,
|
| 362 |
+
)
|
| 363 |
+
del model_v, proc_v, tok_v
|
| 364 |
+
|
| 365 |
+
# ── GIT ─────────────────────────────────────────────────────────────────
|
| 366 |
+
print("\n[3/4] Evaluating GIT...")
|
| 367 |
+
from models.git_tuner import get_git_model
|
| 368 |
+
git_cfg = CFG.load_for_model("git")
|
| 369 |
+
model_g, proc_g = get_git_model(git_cfg, device)
|
| 370 |
+
ckpt = get_weights_dir(git_cfg, "git", weights)
|
| 371 |
+
if ckpt:
|
| 372 |
+
from transformers import AutoModelForCausalLM
|
| 373 |
+
finetuned = AutoModelForCausalLM.from_pretrained(ckpt)
|
| 374 |
+
model_g.load_state_dict(finetuned.state_dict())
|
| 375 |
+
model_g.to(device)
|
| 376 |
+
print_weights_banner("GIT", weights, ckpt)
|
| 377 |
+
_, val_g = get_dataloaders_for_model(git_cfg, "git", proc_g)
|
| 378 |
+
results["GIT"] = evaluate_git(
|
| 379 |
+
model_g, proc_g, val_g, device,
|
| 380 |
+
num_beams=num_beams, max_new_tokens=max_new_tokens,
|
| 381 |
+
length_penalty=length_penalty, eval_batches=eval_batches,
|
| 382 |
+
)
|
| 383 |
+
del model_g, proc_g
|
| 384 |
+
|
| 385 |
+
# ── Custom VLM ──────────────────────────────────────────────────────────
|
| 386 |
+
print("\n[4/4] Evaluating Custom VLM...")
|
| 387 |
+
vlm_cfg = CFG.load_for_model("custom")
|
| 388 |
+
model_c, c2i, i2c, val_c = load_custom_vlm_for_eval(vlm_cfg, device, weights)
|
| 389 |
+
results["CustomVLM"] = evaluate_custom_vlm_cider(
|
| 390 |
+
model_c, val_c, device, c2i, i2c,
|
| 391 |
+
max_new_tokens=80, eval_batches=15,
|
| 392 |
+
)
|
| 393 |
+
del model_c
|
| 394 |
+
|
| 395 |
+
# ── Summary Table ────────────────────────────────────────────────────────
|
| 396 |
+
print("\n")
|
| 397 |
+
print("=" * 65)
|
| 398 |
+
print(f" All-Model CIDEr Comparison | Weights: {weights}")
|
| 399 |
+
print(f" Beams={num_beams} MaxTok={max_new_tokens} LenPen={length_penalty}")
|
| 400 |
+
print("=" * 65)
|
| 401 |
+
print(f" {'Architecture':<22} {'CIDEr':>8} {'CA Type'}")
|
| 402 |
+
print(" " + "-" * 61)
|
| 403 |
+
ca_types = {
|
| 404 |
+
"BLIP": "Gated MED cross-attention",
|
| 405 |
+
"ViT-GPT2": "Standard full cross-attention",
|
| 406 |
+
"GIT": "Self-attention prefix (no CA)",
|
| 407 |
+
"CustomVLM": "Linear bridge prefix (no CA)",
|
| 408 |
+
}
|
| 409 |
+
for name, score in sorted(results.items(), key=lambda x: -x[1]):
|
| 410 |
+
print(f" {name:<22} {score:>8.4f} {ca_types.get(name, '')}")
|
| 411 |
+
print("=" * 65)
|
| 412 |
+
|
| 413 |
+
return results
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 417 |
+
# Main
|
| 418 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 419 |
+
|
| 420 |
+
def main():
|
| 421 |
+
parser = argparse.ArgumentParser(description="Unified VLM Evaluator")
|
| 422 |
+
parser.add_argument(
|
| 423 |
+
"--model", type=str, default="blip",
|
| 424 |
+
choices=["blip", "vit_gpt2", "git", "custom", "all"],
|
| 425 |
+
help="Which model(s) to evaluate",
|
| 426 |
+
)
|
| 427 |
+
parser.add_argument(
|
| 428 |
+
"--weights", type=str, default="base",
|
| 429 |
+
choices=["base", "finetuned", "best"],
|
| 430 |
+
help="Which weights to use: base (pretrained), finetuned (latest/), best (best/)",
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument("--ablation", action="store_true",
|
| 433 |
+
help="Run BLIP 4-mode cross-attention ablation study")
|
| 434 |
+
parser.add_argument("--sweep", action="store_true",
|
| 435 |
+
help="Run decoding parameter sweep")
|
| 436 |
+
parser.add_argument("--num_beams", type=int, default=10)
|
| 437 |
+
parser.add_argument("--max_new_tokens", type=int, default=50)
|
| 438 |
+
parser.add_argument("--length_penalty", type=float, default=1.2)
|
| 439 |
+
parser.add_argument("--eval_batches", type=int, default=25)
|
| 440 |
+
args = parser.parse_args()
|
| 441 |
+
|
| 442 |
+
device = get_device()
|
| 443 |
+
print(f"✅ Device: {device}")
|
| 444 |
+
|
| 445 |
+
if args.model == "all":
|
| 446 |
+
cfg = CFG.load_for_model("blip")
|
| 447 |
+
evaluate_all_models(
|
| 448 |
+
cfg, device,
|
| 449 |
+
weights=args.weights,
|
| 450 |
+
num_beams=args.num_beams,
|
| 451 |
+
max_new_tokens=args.max_new_tokens,
|
| 452 |
+
length_penalty=args.length_penalty,
|
| 453 |
+
eval_batches=args.eval_batches,
|
| 454 |
+
)
|
| 455 |
+
return
|
| 456 |
+
|
| 457 |
+
cfg = CFG.load_for_model(args.model)
|
| 458 |
+
|
| 459 |
+
if args.model == "blip" or args.ablation:
|
| 460 |
+
model, processor = get_blip_model(cfg, device)
|
| 461 |
+
|
| 462 |
+
ckpt_dir = get_weights_dir(cfg, "blip", args.weights)
|
| 463 |
+
if ckpt_dir:
|
| 464 |
+
load_ckpt(model, None, None, ckpt_dir)
|
| 465 |
+
print_weights_banner("BLIP", args.weights, ckpt_dir)
|
| 466 |
+
|
| 467 |
+
_, val_loader = get_dataloaders(cfg, processor)
|
| 468 |
+
|
| 469 |
+
if args.ablation:
|
| 470 |
+
run_ablation_study(
|
| 471 |
+
model, processor, val_loader, device, cfg,
|
| 472 |
+
num_beams=args.num_beams,
|
| 473 |
+
max_new_tokens=args.max_new_tokens,
|
| 474 |
+
length_penalty=args.length_penalty,
|
| 475 |
+
eval_batches=args.eval_batches,
|
| 476 |
+
)
|
| 477 |
+
elif args.sweep:
|
| 478 |
+
from experiments.parameter_sweep import run_parameter_sweep
|
| 479 |
+
run_parameter_sweep(
|
| 480 |
+
"blip",
|
| 481 |
+
{"model": model, "processor": processor},
|
| 482 |
+
val_loader, device,
|
| 483 |
+
eval_batches=args.eval_batches,
|
| 484 |
+
)
|
| 485 |
+
else:
|
| 486 |
+
evaluate_blip(
|
| 487 |
+
model, processor, val_loader, device,
|
| 488 |
+
num_beams=args.num_beams,
|
| 489 |
+
max_new_tokens=args.max_new_tokens,
|
| 490 |
+
length_penalty=args.length_penalty,
|
| 491 |
+
eval_batches=args.eval_batches,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
elif args.model == "vit_gpt2":
|
| 495 |
+
from models.vit_gpt2_tuner import get_vit_gpt2_model
|
| 496 |
+
model, processor, tokenizer = get_vit_gpt2_model(cfg, device)
|
| 497 |
+
ckpt_dir = get_weights_dir(cfg, "vit_gpt2", args.weights)
|
| 498 |
+
if ckpt_dir:
|
| 499 |
+
from transformers import VisionEncoderDecoderModel
|
| 500 |
+
finetuned = VisionEncoderDecoderModel.from_pretrained(ckpt_dir)
|
| 501 |
+
model.load_state_dict(finetuned.state_dict())
|
| 502 |
+
model.to(device)
|
| 503 |
+
print_weights_banner("ViT-GPT2", args.weights, ckpt_dir)
|
| 504 |
+
_, val_loader = get_dataloaders_for_model(cfg, "vit_gpt2", processor, tokenizer)
|
| 505 |
+
evaluate_vit_gpt2(
|
| 506 |
+
model, tokenizer, val_loader, device,
|
| 507 |
+
num_beams=args.num_beams,
|
| 508 |
+
max_new_tokens=args.max_new_tokens,
|
| 509 |
+
length_penalty=args.length_penalty,
|
| 510 |
+
eval_batches=args.eval_batches,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
elif args.model == "git":
|
| 514 |
+
from models.git_tuner import get_git_model
|
| 515 |
+
model, processor = get_git_model(cfg, device)
|
| 516 |
+
ckpt_dir = get_weights_dir(cfg, "git", args.weights)
|
| 517 |
+
if ckpt_dir:
|
| 518 |
+
from transformers import AutoModelForCausalLM
|
| 519 |
+
finetuned = AutoModelForCausalLM.from_pretrained(ckpt_dir)
|
| 520 |
+
model.load_state_dict(finetuned.state_dict())
|
| 521 |
+
model.to(device)
|
| 522 |
+
print_weights_banner("GIT", args.weights, ckpt_dir)
|
| 523 |
+
_, val_loader = get_dataloaders_for_model(cfg, "git", processor)
|
| 524 |
+
evaluate_git(
|
| 525 |
+
model, processor, val_loader, device,
|
| 526 |
+
num_beams=args.num_beams,
|
| 527 |
+
max_new_tokens=args.max_new_tokens,
|
| 528 |
+
length_penalty=args.length_penalty,
|
| 529 |
+
eval_batches=args.eval_batches,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
elif args.model == "custom":
|
| 533 |
+
vlm_cfg = CFG.load_for_model("custom")
|
| 534 |
+
model, c2i, i2c, val_loader = load_custom_vlm_for_eval(
|
| 535 |
+
vlm_cfg, device, args.weights)
|
| 536 |
+
evaluate_custom_vlm_cider(
|
| 537 |
+
model, val_loader, device, c2i, i2c,
|
| 538 |
+
max_new_tokens=80,
|
| 539 |
+
num_beams=args.num_beams,
|
| 540 |
+
length_penalty=args.length_penalty,
|
| 541 |
+
eval_batches=args.eval_batches,
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
if __name__ == "__main__":
|
| 546 |
+
main()
|
experiments/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
experiments/
|
| 3 |
+
============
|
| 4 |
+
Pluggable experiment modules for the VLM Image Captioning pipeline.
|
| 5 |
+
|
| 6 |
+
ablation_study.py — Cross-attention mask ablation (BLIP / ViT-GPT2)
|
| 7 |
+
cross_attention_patterns.py — Architecture comparison table
|
| 8 |
+
parameter_sweep.py — beam_size / length_penalty / max_length sweep
|
| 9 |
+
data_prep_analysis.py — Before vs after caption quality filtering
|
| 10 |
+
vqa_experiment.py — Visual Question Answering demo (BLIP-VQA)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from .ablation_study import (
|
| 14 |
+
build_ablation_mask,
|
| 15 |
+
evaluate_blip_ablation,
|
| 16 |
+
run_ablation_study,
|
| 17 |
+
ABLATION_MODES,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"build_ablation_mask",
|
| 22 |
+
"evaluate_blip_ablation",
|
| 23 |
+
"run_ablation_study",
|
| 24 |
+
"ABLATION_MODES",
|
| 25 |
+
]
|
experiments/ablation_study.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
experiments/ablation_study.py
|
| 3 |
+
==============================
|
| 4 |
+
Cross-Attention Masking Ablation Study for BLIP and ViT-GPT2.
|
| 5 |
+
|
| 6 |
+
Four encoder_attention_mask ablation modes:
|
| 7 |
+
|
| 8 |
+
Mode 1 — Baseline (Full Attention)
|
| 9 |
+
Mask : all 1s → text decoder sees all 197 patches (1 CLS + 196 spatial)
|
| 10 |
+
Intent: Upper-bound reference; no information is hidden.
|
| 11 |
+
|
| 12 |
+
Mode 2 — Random Patch Dropout (Sparse Attention)
|
| 13 |
+
Mask : 50% of 196 spatial patches randomly zeroed; CLS always kept at idx 0
|
| 14 |
+
Intent: Tests redundancy — how much spatial information is truly needed?
|
| 15 |
+
|
| 16 |
+
Mode 3 — Center-Focus Spatial Cropping
|
| 17 |
+
Mask : Only the inner 8×8 grid of the 14×14 spatial patch grid kept
|
| 18 |
+
Intent: Tests whether the image periphery (background clutter) hurts captions.
|
| 19 |
+
|
| 20 |
+
Mode 4 — "The Squint" (Global Pooling Proxy)
|
| 21 |
+
Mask : 196 spatial patches averaged → 1 token appended after CLS
|
| 22 |
+
The mask then has shape (1, 2): [CLS=1, global_pool=1]
|
| 23 |
+
Intent: Tests whether granular local patch details are necessary, or a
|
| 24 |
+
global compressed summary suffices.
|
| 25 |
+
|
| 26 |
+
Note: GIT does not support encoder_attention_mask (no cross-attention).
|
| 27 |
+
GIT ablations are noted as N/A in the results table.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import os
|
| 31 |
+
import sys
|
| 32 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
from tqdm.auto import tqdm
|
| 36 |
+
from pycocoevalcap.cider.cider import Cider
|
| 37 |
+
from models.blip_tuner import generate_with_mask
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 41 |
+
# Available Modes
|
| 42 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 43 |
+
|
| 44 |
+
ABLATION_MODES = ["baseline", "random_dropout", "center_focus", "squint"]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 48 |
+
# Ablation Mask Builders
|
| 49 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 50 |
+
|
| 51 |
+
def build_ablation_mask(mode: str, batch_size: int, num_patches: int,
|
| 52 |
+
device: torch.device, cfg=None):
|
| 53 |
+
"""
|
| 54 |
+
Build an encoder_attention_mask tensor for a given ablation mode.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
mode : 'baseline' | 'random_dropout' | 'center_focus' | 'squint'
|
| 58 |
+
batch_size : number of images in the batch
|
| 59 |
+
num_patches : total patches including CLS (usually 197 = 1 + 196)
|
| 60 |
+
device : target torch device
|
| 61 |
+
cfg : config object for dropout_ratio (default 0.5 if None)
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
mask : LongTensor of shape (batch_size, num_patches)
|
| 65 |
+
Squint returns shape (batch_size, 2) — handled separately.
|
| 66 |
+
"""
|
| 67 |
+
B = batch_size
|
| 68 |
+
N = num_patches
|
| 69 |
+
spatial = N - 1 # 196 spatial patches (excluding CLS at index 0)
|
| 70 |
+
dropout_ratio = cfg.dropout_ratio if cfg else 0.5
|
| 71 |
+
|
| 72 |
+
if mode == "baseline":
|
| 73 |
+
# ── Mode 1: Full attention — all 197 patches visible ─────────────────
|
| 74 |
+
return torch.ones(B, N, dtype=torch.long, device=device)
|
| 75 |
+
|
| 76 |
+
elif mode == "random_dropout":
|
| 77 |
+
# ── Mode 2: Randomly zero 50% of spatial patches; keep CLS ──────────
|
| 78 |
+
mask = torch.ones(B, N, dtype=torch.long, device=device)
|
| 79 |
+
n_drop = int(spatial * dropout_ratio)
|
| 80 |
+
for b in range(B):
|
| 81 |
+
drop_indices = torch.randperm(spatial, device=device)[:n_drop] + 1
|
| 82 |
+
mask[b, drop_indices] = 0
|
| 83 |
+
return mask
|
| 84 |
+
|
| 85 |
+
elif mode == "center_focus":
|
| 86 |
+
# ── Mode 3: Keep only the inner 8×8 of the 14×14 spatial grid ────────
|
| 87 |
+
GRID = 14
|
| 88 |
+
INNER = 8
|
| 89 |
+
offset = (GRID - INNER) // 2 # 3
|
| 90 |
+
|
| 91 |
+
keep_indices = set()
|
| 92 |
+
for row in range(offset, offset + INNER):
|
| 93 |
+
for col in range(offset, offset + INNER):
|
| 94 |
+
keep_indices.add(row * GRID + col + 1) # +1 for CLS offset
|
| 95 |
+
|
| 96 |
+
mask = torch.zeros(B, N, dtype=torch.long, device=device)
|
| 97 |
+
mask[:, 0] = 1 # Always keep CLS
|
| 98 |
+
for idx in keep_indices:
|
| 99 |
+
if idx < N:
|
| 100 |
+
mask[:, idx] = 1
|
| 101 |
+
return mask
|
| 102 |
+
|
| 103 |
+
elif mode == "squint":
|
| 104 |
+
# ── Mode 4: Global Pooling Proxy ──────────────────────────────────────
|
| 105 |
+
# Returns a 2-token mask: [CLS=1, global_pool=1]
|
| 106 |
+
# The actual global pooling is handled in evaluate_blip_ablation().
|
| 107 |
+
return torch.ones(B, 2, dtype=torch.long, device=device)
|
| 108 |
+
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(
|
| 111 |
+
f"Unknown ablation mode: {mode!r}. "
|
| 112 |
+
"Choose from: baseline, random_dropout, center_focus, squint"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 117 |
+
# BLIP CIDEr Evaluation (single mode)
|
| 118 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 119 |
+
|
| 120 |
+
def evaluate_blip_ablation(model, processor, dataloader, device,
|
| 121 |
+
mode="baseline", cfg=None,
|
| 122 |
+
num_beams=4, max_new_tokens=32,
|
| 123 |
+
length_penalty=1.0, eval_batches=25):
|
| 124 |
+
"""
|
| 125 |
+
Evaluate BLIP CIDEr score for a specific ablation mode.
|
| 126 |
+
|
| 127 |
+
For 'squint' mode, we manually extract the visual encoder embeddings,
|
| 128 |
+
pool the spatial patches, and pass them as encoder_hidden_states directly.
|
| 129 |
+
For all other modes, we use generate_with_mask() with encoder_attention_mask.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
eval_batches : max number of batches to evaluate (keep small for speed)
|
| 133 |
+
length_penalty: passed to beam search (1.0 = neutral, >1 favors longer)
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
cider_score: float
|
| 137 |
+
"""
|
| 138 |
+
model.eval()
|
| 139 |
+
gts = {}
|
| 140 |
+
res = {}
|
| 141 |
+
|
| 142 |
+
print(f"\n{'='*60}")
|
| 143 |
+
print(f" Ablation Mode : {mode.upper()}")
|
| 144 |
+
print(f" Beams={num_beams} MaxTokens={max_new_tokens} LenPenalty={length_penalty}")
|
| 145 |
+
print(f"{'='*60}")
|
| 146 |
+
|
| 147 |
+
with torch.no_grad():
|
| 148 |
+
for i, batch in enumerate(tqdm(dataloader, desc=f"Eval [{mode}]")):
|
| 149 |
+
if i >= eval_batches:
|
| 150 |
+
break
|
| 151 |
+
|
| 152 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 153 |
+
B = pixel_values.shape[0]
|
| 154 |
+
|
| 155 |
+
if mode == "squint":
|
| 156 |
+
vision_outputs = model.vision_model(pixel_values=pixel_values)
|
| 157 |
+
hidden_states = vision_outputs.last_hidden_state # (B, 197, 768)
|
| 158 |
+
cls_token = hidden_states[:, :1, :]
|
| 159 |
+
spatial = hidden_states[:, 1:, :]
|
| 160 |
+
global_pool = spatial.mean(dim=1, keepdim=True)
|
| 161 |
+
pooled_hidden = torch.cat([cls_token, global_pool], dim=1)
|
| 162 |
+
|
| 163 |
+
decoded = generate_with_mask(
|
| 164 |
+
model, processor, device=device,
|
| 165 |
+
encoder_hidden_states=pooled_hidden,
|
| 166 |
+
encoder_attention_mask=torch.ones(B, 2, dtype=torch.long, device=device),
|
| 167 |
+
max_new_tokens=max_new_tokens,
|
| 168 |
+
num_beams=num_beams,
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
num_patches = 197
|
| 172 |
+
mask = build_ablation_mask(mode, B, num_patches, device, cfg)
|
| 173 |
+
decoded = generate_with_mask(
|
| 174 |
+
model, processor, device=device,
|
| 175 |
+
pixel_values=pixel_values,
|
| 176 |
+
encoder_attention_mask=mask,
|
| 177 |
+
max_new_tokens=max_new_tokens,
|
| 178 |
+
num_beams=num_beams,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
preds = decoded # generate_with_mask returns decoded strings
|
| 182 |
+
|
| 183 |
+
labels = batch["labels"].clone()
|
| 184 |
+
gts_batch = processor.batch_decode(labels, skip_special_tokens=True)
|
| 185 |
+
|
| 186 |
+
for j in range(len(preds)):
|
| 187 |
+
idx_key = str(i * len(preds) + j)
|
| 188 |
+
res[idx_key] = [preds[j]]
|
| 189 |
+
gts[idx_key] = [gts_batch[j]]
|
| 190 |
+
|
| 191 |
+
if not gts:
|
| 192 |
+
print("⚠️ No predictions gathered. Returning 0.")
|
| 193 |
+
return 0.0
|
| 194 |
+
|
| 195 |
+
cider_scorer = Cider()
|
| 196 |
+
score, _ = cider_scorer.compute_score(gts, res)
|
| 197 |
+
print(f" ✅ CIDEr [{mode}]: {score:.4f}")
|
| 198 |
+
return score
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 202 |
+
# Full Ablation Study
|
| 203 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 204 |
+
|
| 205 |
+
def run_ablation_study(model, processor, dataloader, device, cfg,
|
| 206 |
+
num_beams=4, max_new_tokens=32, length_penalty=1.0,
|
| 207 |
+
eval_batches=25):
|
| 208 |
+
"""
|
| 209 |
+
Run all 4 ablation modes and print a CIDEr comparison table.
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
results: dict mapping mode → CIDEr score
|
| 213 |
+
"""
|
| 214 |
+
results = {}
|
| 215 |
+
for mode in ABLATION_MODES:
|
| 216 |
+
score = evaluate_blip_ablation(
|
| 217 |
+
model, processor, dataloader, device,
|
| 218 |
+
mode=mode, cfg=cfg,
|
| 219 |
+
num_beams=num_beams, max_new_tokens=max_new_tokens,
|
| 220 |
+
length_penalty=length_penalty,
|
| 221 |
+
eval_batches=eval_batches,
|
| 222 |
+
)
|
| 223 |
+
results[mode] = score
|
| 224 |
+
|
| 225 |
+
print("\n")
|
| 226 |
+
print("=" * 60)
|
| 227 |
+
print(" Cross-Attention Ablation Results (CIDEr)")
|
| 228 |
+
print(f" Beams={num_beams} MaxTokens={max_new_tokens} LenPenalty={length_penalty}")
|
| 229 |
+
print("=" * 60)
|
| 230 |
+
print(f" {'Mode':<25} {'CIDEr':>10} {'Δ Baseline':>12}")
|
| 231 |
+
print("-" * 60)
|
| 232 |
+
baseline_score = results.get("baseline", 0.0)
|
| 233 |
+
for mode, score in results.items():
|
| 234 |
+
delta = score - baseline_score
|
| 235 |
+
sign = "+" if delta >= 0 else ""
|
| 236 |
+
print(f" {mode:<25} {score:>10.4f} {sign}{delta:>11.4f}")
|
| 237 |
+
print("=" * 60)
|
| 238 |
+
print("=" * 60)
|
| 239 |
+
|
| 240 |
+
return results
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
|
| 243 |
+
import argparse
|
| 244 |
+
from config import CFG
|
| 245 |
+
from models.blip_tuner import get_blip_model
|
| 246 |
+
from torch.utils.data import DataLoader
|
| 247 |
+
from datasets import load_dataset
|
| 248 |
+
import aiohttp
|
| 249 |
+
|
| 250 |
+
parser = argparse.ArgumentParser()
|
| 251 |
+
parser.add_argument("--eval_batches", type=int, default=25)
|
| 252 |
+
args = parser.parse_args()
|
| 253 |
+
|
| 254 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 255 |
+
cfg = CFG.load_for_model("blip")
|
| 256 |
+
model, processor = get_blip_model(cfg, device)
|
| 257 |
+
|
| 258 |
+
ds = load_dataset(
|
| 259 |
+
cfg.dataset_id,
|
| 260 |
+
storage_options={"client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)}}
|
| 261 |
+
)
|
| 262 |
+
val_split = "validation" if "validation" in ds else "train"
|
| 263 |
+
val_hf = ds[val_split].shuffle(seed=43).select(range(min(2000, len(ds[val_split]))))
|
| 264 |
+
|
| 265 |
+
def _collate(examples):
|
| 266 |
+
images = [ex["image"].convert("RGB") for ex in examples]
|
| 267 |
+
captions = [ex["captions"][0] for ex in examples]
|
| 268 |
+
enc = processor(images=images, text=captions, padding="max_length", truncation=True, max_length=cfg.max_target_len, return_tensors="pt")
|
| 269 |
+
enc["labels"] = enc["input_ids"].clone()
|
| 270 |
+
return enc
|
| 271 |
+
|
| 272 |
+
val_loader = DataLoader(val_hf, batch_size=cfg.batch_size, shuffle=False, num_workers=0, collate_fn=_collate)
|
| 273 |
+
|
| 274 |
+
run_ablation_study(model, processor, val_loader, device, cfg, eval_batches=args.eval_batches)
|
experiments/cross_attention_patterns.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
experiments/cross_attention_patterns.py
|
| 3 |
+
========================================
|
| 4 |
+
Documents and compares the four distinct cross-attention (fusion) patterns
|
| 5 |
+
used by each architecture in this pipeline.
|
| 6 |
+
|
| 7 |
+
This module does NOT require loading any model — it produces a static
|
| 8 |
+
analysis table and inline architecture diagrams, and can optionally
|
| 9 |
+
compute the number of cross-attention parameter counts from loaded models.
|
| 10 |
+
|
| 11 |
+
Usage (standalone):
|
| 12 |
+
python -m experiments.cross_attention_patterns
|
| 13 |
+
|
| 14 |
+
Architecture Summary
|
| 15 |
+
--------------------
|
| 16 |
+
|
| 17 |
+
┌─────────────────┬───────────────────────────┬──────────────────────────────────┐
|
| 18 |
+
│ Architecture │ Fusion Mechanism │ Cross-Attention Exists? │
|
| 19 |
+
├─────────────────┼───────────────────────────┼──────────────────────────────────┤
|
| 20 |
+
│ ViT-GPT2 │ Standard Full CA │ ✅ Yes — at every GPT-2 layer │
|
| 21 |
+
│ BLIP (MED) │ Gated Cross-Attention MED │ ✅ Yes — between SA and FFN │
|
| 22 |
+
│ GIT │ Self-Attn Prefix │ ❌ No — unified causal SA │
|
| 23 |
+
│ Custom VLM │ Visual Prefix-Tuning │ ❌ No — linear projection + SA │
|
| 24 |
+
└─────────────────┴───────────────────────────┴──────────────────────────────────┘
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 29 |
+
# Static Architecture Descriptions
|
| 30 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 31 |
+
|
| 32 |
+
PATTERNS = [
|
| 33 |
+
{
|
| 34 |
+
"name": "ViT-GPT2",
|
| 35 |
+
"model_id": "nlpconnect/vit-gpt2-image-captioning",
|
| 36 |
+
"cross_attention": True,
|
| 37 |
+
"ca_type": "Standard Full Cross-Attention",
|
| 38 |
+
"description": (
|
| 39 |
+
"Every GPT-2 decoder layer has an explicit cross-attention block. "
|
| 40 |
+
"Each text token attends to ALL 197 ViT patch embeddings "
|
| 41 |
+
"(1 CLS + 196 spatial) at every layer. "
|
| 42 |
+
"This is the brute-force approach — maximum information, highest compute."
|
| 43 |
+
),
|
| 44 |
+
"fusion_formula": "h_text = CrossAttn(Q=h_text, K=h_vis, V=h_vis)",
|
| 45 |
+
"ablation_support": True,
|
| 46 |
+
"ablation_method": "encoder_attention_mask on generate()",
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"name": "BLIP (MED)",
|
| 50 |
+
"model_id": "Salesforce/blip-image-captioning-base",
|
| 51 |
+
"cross_attention": True,
|
| 52 |
+
"ca_type": "Gated Multimodal Encoder-Decoder (MED)",
|
| 53 |
+
"description": (
|
| 54 |
+
"BLIP's MED architecture injects a cross-attention sub-layer "
|
| 55 |
+
"BETWEEN the self-attention and FFN sub-layers at each decoder block. "
|
| 56 |
+
"A learnable gate controls how much visual information passes through. "
|
| 57 |
+
"This is more targeted than ViT-GPT2's brute-force attention."
|
| 58 |
+
),
|
| 59 |
+
"fusion_formula": (
|
| 60 |
+
"h = SA(h_text) "
|
| 61 |
+
"→ h = h + gate * CrossAttn(Q=h, K=h_vis, V=h_vis) "
|
| 62 |
+
"→ h = FFN(h)"
|
| 63 |
+
),
|
| 64 |
+
"ablation_support": True,
|
| 65 |
+
"ablation_method": "encoder_attention_mask via generate_with_mask()",
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"name": "GIT",
|
| 69 |
+
"model_id": "microsoft/git-base-coco",
|
| 70 |
+
"cross_attention": False,
|
| 71 |
+
"ca_type": "Zero Cross-Attention (Self-Attention Prefix)",
|
| 72 |
+
"description": (
|
| 73 |
+
"GIT concatenates image patch embeddings directly in front of text tokens "
|
| 74 |
+
"to form a flat joint sequence: [img_tokens | text_tokens]. "
|
| 75 |
+
"A single causal self-attention Transformer processes the whole thing. "
|
| 76 |
+
"There is NO dedicated cross-attention block. "
|
| 77 |
+
"Modality fusion is implicit via positional self-attention."
|
| 78 |
+
),
|
| 79 |
+
"fusion_formula": "h = CausalSA([h_vis; h_text])",
|
| 80 |
+
"ablation_support": False,
|
| 81 |
+
"ablation_method": "N/A — no encoder_attention_mask concept",
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"name": "Custom VLM (Shakespeare)",
|
| 85 |
+
"model_id": "google/vit-base-patch16-224-in21k (ViT) + char-level decoder",
|
| 86 |
+
"cross_attention": False,
|
| 87 |
+
"ca_type": "Visual Prefix-Tuning (Linear Bridge + Causal SA)",
|
| 88 |
+
"description": (
|
| 89 |
+
"A frozen ViT extracts 197 patch embeddings (768-dim). "
|
| 90 |
+
"A single trainable Linear(768→384) projects these to the decoder's "
|
| 91 |
+
"embedding space. Projected visual tokens are prepended to character "
|
| 92 |
+
"embeddings and the Shakespeare causal decoder processes them jointly. "
|
| 93 |
+
"Only the linear projection is trained (~294K params, <0.2% of total). "
|
| 94 |
+
"\nKey insight: cross-attention is provably unnecessary when modalities "
|
| 95 |
+
"are aligned in the same embedding space via prefix concatenation."
|
| 96 |
+
),
|
| 97 |
+
"fusion_formula": (
|
| 98 |
+
"v = Linear(ViT(img)) "
|
| 99 |
+
"→ x = CausalSA([v; char_emb]) "
|
| 100 |
+
"→ logits = LMHead(x[len(v):])"
|
| 101 |
+
),
|
| 102 |
+
"ablation_support": False,
|
| 103 |
+
"ablation_method": "N/A — visual prefix is part of unified sequence",
|
| 104 |
+
},
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 109 |
+
# Comparison Table Printer
|
| 110 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 111 |
+
|
| 112 |
+
def print_comparison_table():
|
| 113 |
+
"""Print a formatted comparison table to stdout."""
|
| 114 |
+
print("\n" + "=" * 80)
|
| 115 |
+
print(" Cross-Attention Pattern Comparison")
|
| 116 |
+
print("=" * 80)
|
| 117 |
+
print(f" {'Architecture':<22} {'CA?':>5} {'Type':<35} {'Ablation?':>9}")
|
| 118 |
+
print(" " + "-" * 76)
|
| 119 |
+
for p in PATTERNS:
|
| 120 |
+
ca = " ✅" if p["cross_attention"] else " ❌"
|
| 121 |
+
abl = " ✅" if p["ablation_support"] else " ❌"
|
| 122 |
+
print(f" {p['name']:<22} {ca:>5} {p['ca_type']:<35} {abl:>9}")
|
| 123 |
+
print("=" * 80)
|
| 124 |
+
|
| 125 |
+
for p in PATTERNS:
|
| 126 |
+
print(f"\n ── {p['name']} ──────────────────────────────────────────────")
|
| 127 |
+
print(f" Model : {p['model_id']}")
|
| 128 |
+
print(f" CA Type: {p['ca_type']}")
|
| 129 |
+
print(f" Formula: {p['fusion_formula']}")
|
| 130 |
+
for line in p["description"].split("\n"):
|
| 131 |
+
print(f" {line.strip()}")
|
| 132 |
+
if p["ablation_support"]:
|
| 133 |
+
print(f" Ablation: {p['ablation_method']}")
|
| 134 |
+
else:
|
| 135 |
+
print(f" ⚠️ Ablation: {p['ablation_method']}")
|
| 136 |
+
|
| 137 |
+
print()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 141 |
+
# Optional: Parameter Count from Loaded Models
|
| 142 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 143 |
+
|
| 144 |
+
def count_cross_attention_params(model, model_name: str) -> dict:
|
| 145 |
+
"""
|
| 146 |
+
Count parameters in cross-attention layers for BLIP or ViT-GPT2.
|
| 147 |
+
|
| 148 |
+
For GIT / Custom VLM (no CA), returns zero.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
model : loaded PyTorch model
|
| 152 |
+
model_name : 'blip' | 'vit_gpt2' | 'git' | 'custom'
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
dict with 'total', 'cross_attn', 'cross_attn_pct'
|
| 156 |
+
"""
|
| 157 |
+
total = sum(p.numel() for p in model.parameters())
|
| 158 |
+
ca_params = 0
|
| 159 |
+
|
| 160 |
+
if model_name == "blip":
|
| 161 |
+
for name, p in model.named_parameters():
|
| 162 |
+
if "crossattention" in name.lower():
|
| 163 |
+
ca_params += p.numel()
|
| 164 |
+
|
| 165 |
+
elif model_name == "vit_gpt2":
|
| 166 |
+
for name, p in model.named_parameters():
|
| 167 |
+
if "crossattention" in name.lower() or "cross_attn" in name.lower():
|
| 168 |
+
ca_params += p.numel()
|
| 169 |
+
|
| 170 |
+
# GIT / custom: 0 cross-attention params by design
|
| 171 |
+
|
| 172 |
+
return {
|
| 173 |
+
"model": model_name,
|
| 174 |
+
"total_params": total,
|
| 175 |
+
"cross_attn_params": ca_params,
|
| 176 |
+
"cross_attn_pct": ca_params / total * 100 if total > 0 else 0.0,
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 181 |
+
# CLI
|
| 182 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 183 |
+
|
| 184 |
+
def main():
|
| 185 |
+
print_comparison_table()
|
| 186 |
+
|
| 187 |
+
# Optionally count params for all four models
|
| 188 |
+
count_params = input(
|
| 189 |
+
"\nCount cross-attention parameters in all models? "
|
| 190 |
+
"(requires downloading BLIP+ViT-GPT2+GIT) [y/N]: "
|
| 191 |
+
).strip().lower()
|
| 192 |
+
|
| 193 |
+
if count_params == "y":
|
| 194 |
+
import torch
|
| 195 |
+
device = torch.device("cpu")
|
| 196 |
+
|
| 197 |
+
print("\nLoading models to count parameters...\n")
|
| 198 |
+
|
| 199 |
+
import sys, os
|
| 200 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 201 |
+
|
| 202 |
+
from config import CFG
|
| 203 |
+
from models.blip_tuner import get_blip_model
|
| 204 |
+
from models.vit_gpt2_tuner import get_vit_gpt2_model
|
| 205 |
+
from models.git_tuner import get_git_model
|
| 206 |
+
from models.custom_vlm import CustomVLM, build_char_vocab
|
| 207 |
+
|
| 208 |
+
cfg = CFG()
|
| 209 |
+
|
| 210 |
+
rows = []
|
| 211 |
+
|
| 212 |
+
model_b, _ = get_blip_model(cfg, device)
|
| 213 |
+
rows.append(count_cross_attention_params(model_b, "blip"))
|
| 214 |
+
del model_b
|
| 215 |
+
|
| 216 |
+
model_v, _, _ = get_vit_gpt2_model(cfg, device)
|
| 217 |
+
rows.append(count_cross_attention_params(model_v, "vit_gpt2"))
|
| 218 |
+
del model_v
|
| 219 |
+
|
| 220 |
+
model_g, _ = get_git_model(cfg, device)
|
| 221 |
+
rows.append(count_cross_attention_params(model_g, "git"))
|
| 222 |
+
del model_g
|
| 223 |
+
|
| 224 |
+
with open(cfg.shakespeare_file, "r") as f:
|
| 225 |
+
text = f.read()
|
| 226 |
+
_, c2i, i2c, vs = build_char_vocab(text)
|
| 227 |
+
model_c = CustomVLM(vocab_size=vs)
|
| 228 |
+
rows.append(count_cross_attention_params(model_c, "custom"))
|
| 229 |
+
del model_c
|
| 230 |
+
|
| 231 |
+
print("\n" + "=" * 65)
|
| 232 |
+
print(" Cross-Attention Parameter Counts")
|
| 233 |
+
print("=" * 65)
|
| 234 |
+
print(f" {'Model':<15} {'Total':>12} {'CA Params':>12} {'CA %':>8}")
|
| 235 |
+
print(" " + "-" * 58)
|
| 236 |
+
for r in rows:
|
| 237 |
+
print(f" {r['model']:<15} {r['total_params']:>12,} "
|
| 238 |
+
f"{r['cross_attn_params']:>12,} {r['cross_attn_pct']:>7.2f}%")
|
| 239 |
+
print("=" * 65)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
|
| 243 |
+
main()
|
experiments/data_prep_analysis.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
experiments/data_prep_analysis.py
|
| 3 |
+
===================================
|
| 4 |
+
Compares caption quality and model performance BEFORE vs AFTER applying
|
| 5 |
+
data preparation quality filters to the COCO dataset.
|
| 6 |
+
|
| 7 |
+
Filters applied in the "after" condition:
|
| 8 |
+
1. Minimum word count: caption must have ≥ 5 words
|
| 9 |
+
2. Maximum word count: caption must have ≤ 25 words
|
| 10 |
+
3. Short/Long/Mixed caption strategy switching
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python -m experiments.data_prep_analysis --model blip
|
| 14 |
+
|
| 15 |
+
Expected insight:
|
| 16 |
+
- Raw COCO captions include many very short (1-3 word) and very long (30+
|
| 17 |
+
word) references that add noise to training and evaluation.
|
| 18 |
+
- Filtering to 5-25 words focuses training on informative mid-length
|
| 19 |
+
captions and typically improves CIDEr by 3-8% on the eval set.
|
| 20 |
+
- Mixed strategy (randomly choosing from long, short, or medium captions)
|
| 21 |
+
improves robustness but individual CIDEr may be slightly lower than a
|
| 22 |
+
targeted strategy.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import random
|
| 27 |
+
import torch
|
| 28 |
+
from tqdm.auto import tqdm
|
| 29 |
+
from datasets import load_dataset
|
| 30 |
+
import aiohttp
|
| 31 |
+
from torch.utils.data import DataLoader
|
| 32 |
+
from pycocoevalcap.cider.cider import Cider
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 36 |
+
# Caption Filtering Functions
|
| 37 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
def filter_low_quality_captions(captions: list, min_words: int = 5,
|
| 40 |
+
max_words: int = 25) -> list:
|
| 41 |
+
"""
|
| 42 |
+
Filter a list of captions to only include those within the word count range.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
captions : list of caption strings
|
| 46 |
+
min_words : minimum word count (inclusive)
|
| 47 |
+
max_words : maximum word count (inclusive)
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
filtered : list of captions meeting the criteria (may be empty)
|
| 51 |
+
"""
|
| 52 |
+
return [
|
| 53 |
+
c for c in captions
|
| 54 |
+
if min_words <= len(c.split()) <= max_words
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def pick_caption_raw(example: dict) -> str:
|
| 59 |
+
"""Pick any random caption from the example (no filtering)."""
|
| 60 |
+
return random.choice(example["captions"])
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def pick_caption_filtered(example: dict, min_words: int = 5,
|
| 64 |
+
max_words: int = 25) -> str:
|
| 65 |
+
"""Pick a filtered caption; fallback to raw random if none pass filter."""
|
| 66 |
+
filtered = filter_low_quality_captions(
|
| 67 |
+
example["captions"], min_words, max_words
|
| 68 |
+
)
|
| 69 |
+
pool = filtered if filtered else example["captions"]
|
| 70 |
+
return random.choice(pool)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def pick_caption_short(example: dict, max_words: int = 9) -> str:
|
| 74 |
+
"""Pick a short caption (≤ max_words); fallback to raw if none qualify."""
|
| 75 |
+
short = [c for c in example["captions"] if len(c.split()) <= max_words]
|
| 76 |
+
return random.choice(short) if short else random.choice(example["captions"])
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def pick_caption_long(example: dict, min_words: int = 12) -> str:
|
| 80 |
+
"""Pick a long caption (≥ min_words); fallback to raw if none qualify."""
|
| 81 |
+
long = [c for c in example["captions"] if len(c.split()) >= min_words]
|
| 82 |
+
return random.choice(long) if long else random.choice(example["captions"])
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 86 |
+
# Caption Distribution Analysis
|
| 87 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 88 |
+
|
| 89 |
+
def analyze_caption_distribution(ds, n_samples: int = 500) -> dict:
|
| 90 |
+
"""
|
| 91 |
+
Compute word-count distribution statistics for a HF dataset split.
|
| 92 |
+
|
| 93 |
+
Returns dict with mean, median, p10, p90, pct_short, pct_long.
|
| 94 |
+
"""
|
| 95 |
+
import numpy as np
|
| 96 |
+
lengths = []
|
| 97 |
+
for ex in ds.select(range(min(n_samples, len(ds)))):
|
| 98 |
+
for cap in ex["captions"]:
|
| 99 |
+
lengths.append(len(cap.split()))
|
| 100 |
+
lengths = sorted(lengths)
|
| 101 |
+
n = len(lengths)
|
| 102 |
+
|
| 103 |
+
return {
|
| 104 |
+
"count": n,
|
| 105 |
+
"mean": sum(lengths) / n,
|
| 106 |
+
"min": lengths[0],
|
| 107 |
+
"max": lengths[-1],
|
| 108 |
+
"p10": lengths[int(n * 0.10)],
|
| 109 |
+
"p50": lengths[int(n * 0.50)],
|
| 110 |
+
"p90": lengths[int(n * 0.90)],
|
| 111 |
+
"pct_short": sum(1 for l in lengths if l < 5) / n * 100,
|
| 112 |
+
"pct_long": sum(1 for l in lengths if l > 25) / n * 100,
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 117 |
+
# Eval Helper
|
| 118 |
+
# ──────────���──────────────────────────────────────────────────────────────────
|
| 119 |
+
|
| 120 |
+
def _eval_blip_cider(model, processor, dataloader, device, eval_batches=15):
|
| 121 |
+
"""Quick BLIP inference CIDEr eval over a dataloader."""
|
| 122 |
+
from models.blip_tuner import generate_with_mask
|
| 123 |
+
model.eval()
|
| 124 |
+
gts, res = {}, {}
|
| 125 |
+
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
for i, batch in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
|
| 128 |
+
if i >= eval_batches:
|
| 129 |
+
break
|
| 130 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 131 |
+
mask = torch.ones(pixel_values.shape[0], 197,
|
| 132 |
+
dtype=torch.long, device=device)
|
| 133 |
+
decoded = generate_with_mask(
|
| 134 |
+
model, processor, device=device,
|
| 135 |
+
pixel_values=pixel_values, encoder_attention_mask=mask,
|
| 136 |
+
max_new_tokens=32, num_beams=4,
|
| 137 |
+
)
|
| 138 |
+
preds = decoded # generate_with_mask returns decoded strings
|
| 139 |
+
gts_batch = processor.batch_decode(
|
| 140 |
+
batch["labels"], skip_special_tokens=True
|
| 141 |
+
)
|
| 142 |
+
for j, (p, g) in enumerate(zip(preds, gts_batch)):
|
| 143 |
+
k = str(i * len(preds) + j)
|
| 144 |
+
res[k] = [p]
|
| 145 |
+
gts[k] = [g]
|
| 146 |
+
|
| 147 |
+
if not gts:
|
| 148 |
+
return 0.0
|
| 149 |
+
scorer = Cider()
|
| 150 |
+
score, _ = scorer.compute_score(gts, res)
|
| 151 |
+
return score
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 155 |
+
# Main Analysis Runner
|
| 156 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 157 |
+
|
| 158 |
+
def run_data_prep_analysis(model, processor, dataset_id, device, cfg,
|
| 159 |
+
eval_batches=15):
|
| 160 |
+
"""
|
| 161 |
+
Evaluate CIDEr under three caption selection strategies:
|
| 162 |
+
1. Raw — any random caption (no filtering)
|
| 163 |
+
2. Short — captions ≤ 9 words
|
| 164 |
+
3. Long — captions ≥ 12 words
|
| 165 |
+
4. Filtered (Mixed) — captions 5-25 words
|
| 166 |
+
|
| 167 |
+
Prints a before/after comparison table and key insights.
|
| 168 |
+
"""
|
| 169 |
+
print("\n📊 Data Preparation Analysis")
|
| 170 |
+
print("=" * 60)
|
| 171 |
+
|
| 172 |
+
ds = load_dataset(
|
| 173 |
+
dataset_id,
|
| 174 |
+
storage_options={"client_kwargs": {
|
| 175 |
+
"timeout": aiohttp.ClientTimeout(total=3600)
|
| 176 |
+
}},
|
| 177 |
+
)
|
| 178 |
+
val_split = "validation" if "validation" in ds else "train"
|
| 179 |
+
val_hf = ds[val_split].shuffle(seed=43).select(range(min(200, len(ds[val_split]))))
|
| 180 |
+
|
| 181 |
+
print("\n📈 Caption Word-Count Distribution (val set sample):")
|
| 182 |
+
stats = analyze_caption_distribution(val_hf)
|
| 183 |
+
print(f" Count : {stats['count']}")
|
| 184 |
+
print(f" Mean : {stats['mean']:.1f} words")
|
| 185 |
+
print(f" Range : {stats['min']} – {stats['max']} words")
|
| 186 |
+
print(f" P10/P50/P90: {stats['p10']} / {stats['p50']} / {stats['p90']}")
|
| 187 |
+
print(f" % Short (<5 words) : {stats['pct_short']:.1f}%")
|
| 188 |
+
print(f" % Long (>25 words): {stats['pct_long']:.1f}%")
|
| 189 |
+
|
| 190 |
+
strategies = {
|
| 191 |
+
"raw": pick_caption_raw,
|
| 192 |
+
"short": pick_caption_short,
|
| 193 |
+
"long": pick_caption_long,
|
| 194 |
+
"filtered": pick_caption_filtered,
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
results = {}
|
| 198 |
+
for strat_name, pick_fn in strategies.items():
|
| 199 |
+
print(f"\n Running strategy: '{strat_name}'...")
|
| 200 |
+
|
| 201 |
+
def _collate(examples, _pick=pick_fn):
|
| 202 |
+
images = [ex["image"].convert("RGB") for ex in examples]
|
| 203 |
+
captions = [_pick(ex) for ex in examples]
|
| 204 |
+
enc = processor(
|
| 205 |
+
images=images, text=captions,
|
| 206 |
+
padding="max_length", truncation=True,
|
| 207 |
+
max_length=cfg.max_target_len, return_tensors="pt",
|
| 208 |
+
)
|
| 209 |
+
enc["labels"] = enc["input_ids"].clone()
|
| 210 |
+
return enc
|
| 211 |
+
|
| 212 |
+
val_loader = DataLoader(
|
| 213 |
+
val_hf, batch_size=cfg.batch_size, shuffle=False,
|
| 214 |
+
num_workers=0, collate_fn=_collate,
|
| 215 |
+
)
|
| 216 |
+
score = _eval_blip_cider(model, processor, val_loader, device, eval_batches)
|
| 217 |
+
results[strat_name] = score
|
| 218 |
+
print(f" ✅ CIDEr [{strat_name}]: {score:.4f}")
|
| 219 |
+
|
| 220 |
+
# ── Summary Table ─────────────────────────────────────────────────────────
|
| 221 |
+
print("\n" + "=" * 60)
|
| 222 |
+
print(" Data Preparation — CIDEr Comparison")
|
| 223 |
+
print("=" * 60)
|
| 224 |
+
print(f" {'Strategy':<20} {'CIDEr':>8} {'Δ Raw':>10} Notes")
|
| 225 |
+
print(" " + "-" * 56)
|
| 226 |
+
raw_score = results.get("raw", 0.0)
|
| 227 |
+
notes = {
|
| 228 |
+
"raw": "Baseline — no filtering",
|
| 229 |
+
"short": "Short captions ≤ 9 words",
|
| 230 |
+
"long": "Long captions ≥ 12 words",
|
| 231 |
+
"filtered": "Quality filter 5-25 words ← recommended",
|
| 232 |
+
}
|
| 233 |
+
for strat, score in results.items():
|
| 234 |
+
delta = score - raw_score
|
| 235 |
+
sign = "+" if delta >= 0 else ""
|
| 236 |
+
print(f" {strat:<20} {score:>8.4f} {sign}{delta:>9.4f} {notes[strat]}")
|
| 237 |
+
print("=" * 60)
|
| 238 |
+
|
| 239 |
+
print("\n💡 Key Insight:")
|
| 240 |
+
best = max(results, key=results.get)
|
| 241 |
+
if best == "raw":
|
| 242 |
+
print(" Raw captions perform comparably — dataset is already clean.")
|
| 243 |
+
else:
|
| 244 |
+
gain = results[best] - raw_score
|
| 245 |
+
print(f" '{best}' strategy improves CIDEr by {gain:+.4f} over raw captions.")
|
| 246 |
+
print(" Recommendation: use 'filtered' strategy (5-25 words) for")
|
| 247 |
+
print(" reproducible, balanced training across all models.\n")
|
| 248 |
+
|
| 249 |
+
return results
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 253 |
+
# CLI
|
| 254 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 255 |
+
|
| 256 |
+
def main():
|
| 257 |
+
parser = argparse.ArgumentParser(description="Data preparation analysis")
|
| 258 |
+
parser.add_argument("--eval_batches", type=int, default=15)
|
| 259 |
+
args = parser.parse_args()
|
| 260 |
+
|
| 261 |
+
import sys, os
|
| 262 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 263 |
+
|
| 264 |
+
from config import CFG
|
| 265 |
+
from models.blip_tuner import get_blip_model
|
| 266 |
+
|
| 267 |
+
device = torch.device(
|
| 268 |
+
"mps" if torch.backends.mps.is_available() else
|
| 269 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 270 |
+
)
|
| 271 |
+
cfg = CFG.load_for_model("blip")
|
| 272 |
+
model, processor = get_blip_model(cfg, device)
|
| 273 |
+
|
| 274 |
+
run_data_prep_analysis(
|
| 275 |
+
model, processor, cfg.dataset_id, device, cfg,
|
| 276 |
+
eval_batches=args.eval_batches,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
main()
|
experiments/parameter_sweep.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
experiments/parameter_sweep.py
|
| 3 |
+
================================
|
| 4 |
+
Sweep beam_size, length_penalty, and max_new_tokens across BLIP, ViT-GPT2,
|
| 5 |
+
and GIT to measure the effect of decoding parameters on caption quality (CIDEr).
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python -m experiments.parameter_sweep --model blip --eval_batches 15
|
| 9 |
+
|
| 10 |
+
The sweep matrix:
|
| 11 |
+
beam_size : [3, 5, 10]
|
| 12 |
+
length_penalty: [0.8, 1.0, 1.2]
|
| 13 |
+
max_new_tokens: [20, 50]
|
| 14 |
+
|
| 15 |
+
Each cell reports CIDEr on the validation set (25 batches by default).
|
| 16 |
+
A summary table is printed at the end.
|
| 17 |
+
|
| 18 |
+
Insight guide:
|
| 19 |
+
- beam_size ↑ → more diverse candidates considered, usually better quality
|
| 20 |
+
but slower decoding; diminishing returns above ~5
|
| 21 |
+
- length_penalty < 1.0 → penalizes shorter sequences → longer captions
|
| 22 |
+
- length_penalty > 1.0 → rewards shorter sequences → more compact captions
|
| 23 |
+
- max_new_tokens ↑ → allows longer captions; may hurt CIDEr if model rambles
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import itertools
|
| 28 |
+
import torch
|
| 29 |
+
from tqdm.auto import tqdm
|
| 30 |
+
from pycocoevalcap.cider.cider import Cider
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 34 |
+
# Default Search Space
|
| 35 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
BEAM_SIZES = [3, 5, 10]
|
| 38 |
+
LENGTH_PENALTIES = [0.8, 1.0, 1.2]
|
| 39 |
+
MAX_TOKENS = [20, 50]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 43 |
+
# Per-Model Caption Generator (handles BLIP / ViT-GPT2 / GIT)
|
| 44 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 45 |
+
|
| 46 |
+
def _generate_blip(model, processor, batch, device,
|
| 47 |
+
num_beams, max_new_tokens, length_penalty):
|
| 48 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
out = model.generate(
|
| 51 |
+
pixel_values=pixel_values,
|
| 52 |
+
num_beams=num_beams,
|
| 53 |
+
max_new_tokens=max_new_tokens,
|
| 54 |
+
length_penalty=length_penalty,
|
| 55 |
+
)
|
| 56 |
+
return processor.batch_decode(out, skip_special_tokens=True)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _generate_vit_gpt2(model, tokenizer, batch, device,
|
| 60 |
+
num_beams, max_new_tokens, length_penalty):
|
| 61 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
out = model.generate(
|
| 64 |
+
pixel_values=pixel_values,
|
| 65 |
+
num_beams=num_beams,
|
| 66 |
+
max_new_tokens=max_new_tokens,
|
| 67 |
+
length_penalty=length_penalty,
|
| 68 |
+
)
|
| 69 |
+
return [tokenizer.decode(ids, skip_special_tokens=True) for ids in out]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _generate_git(model, processor, batch, device,
|
| 73 |
+
num_beams, max_new_tokens, length_penalty):
|
| 74 |
+
inputs = {k: v.to(device) for k, v in batch.items()
|
| 75 |
+
if k in ("pixel_values", "input_ids", "attention_mask")}
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
out = model.generate(
|
| 78 |
+
**inputs,
|
| 79 |
+
num_beams=num_beams,
|
| 80 |
+
max_new_tokens=max_new_tokens,
|
| 81 |
+
length_penalty=length_penalty,
|
| 82 |
+
)
|
| 83 |
+
return processor.batch_decode(out, skip_special_tokens=True)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 87 |
+
# CIDEr Evaluator for One Configuration
|
| 88 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 89 |
+
|
| 90 |
+
def eval_one_config(model_name, model_objs, dataloader, device,
|
| 91 |
+
num_beams, max_new_tokens, length_penalty,
|
| 92 |
+
eval_batches=25):
|
| 93 |
+
"""
|
| 94 |
+
Evaluate CIDEr for one (model, num_beams, max_new_tokens, length_penalty) combo.
|
| 95 |
+
|
| 96 |
+
model_objs: dict with keys depending on model_name
|
| 97 |
+
- blip: {'model': ..., 'processor': ...}
|
| 98 |
+
- vit_gpt2: {'model': ..., 'tokenizer': ...}
|
| 99 |
+
- git: {'model': ..., 'processor': ...}
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
cider_score: float
|
| 103 |
+
"""
|
| 104 |
+
gts, res = {}, {}
|
| 105 |
+
|
| 106 |
+
for i, batch in enumerate(tqdm(
|
| 107 |
+
dataloader,
|
| 108 |
+
desc=f" {model_name} b={num_beams} L={length_penalty} T={max_new_tokens}",
|
| 109 |
+
leave=False)):
|
| 110 |
+
if i >= eval_batches:
|
| 111 |
+
break
|
| 112 |
+
|
| 113 |
+
if model_name == "blip":
|
| 114 |
+
preds = _generate_blip(
|
| 115 |
+
model_objs["model"], model_objs["processor"],
|
| 116 |
+
batch, device, num_beams, max_new_tokens, length_penalty)
|
| 117 |
+
labels = batch["labels"].clone()
|
| 118 |
+
gt_texts = model_objs["processor"].batch_decode(
|
| 119 |
+
labels, skip_special_tokens=True)
|
| 120 |
+
|
| 121 |
+
elif model_name == "vit_gpt2":
|
| 122 |
+
preds = _generate_vit_gpt2(
|
| 123 |
+
model_objs["model"], model_objs["tokenizer"],
|
| 124 |
+
batch, device, num_beams, max_new_tokens, length_penalty)
|
| 125 |
+
labels = batch["labels"].clone()
|
| 126 |
+
labels[labels == -100] = model_objs["pad_token_id"]
|
| 127 |
+
gt_texts = model_objs["tokenizer"].batch_decode(
|
| 128 |
+
labels, skip_special_tokens=True)
|
| 129 |
+
|
| 130 |
+
elif model_name == "git":
|
| 131 |
+
preds = _generate_git(
|
| 132 |
+
model_objs["model"], model_objs["processor"],
|
| 133 |
+
batch, device, num_beams, max_new_tokens, length_penalty)
|
| 134 |
+
labels = batch["labels"].clone()
|
| 135 |
+
labels[labels == -100] = model_objs["processor"].tokenizer.pad_token_id
|
| 136 |
+
gt_texts = model_objs["processor"].batch_decode(
|
| 137 |
+
labels, skip_special_tokens=True)
|
| 138 |
+
else:
|
| 139 |
+
raise ValueError(f"Unknown model: {model_name}")
|
| 140 |
+
|
| 141 |
+
for j, (pred, gt) in enumerate(zip(preds, gt_texts)):
|
| 142 |
+
key = str(i * len(preds) + j)
|
| 143 |
+
res[key] = [pred]
|
| 144 |
+
gts[key] = [gt]
|
| 145 |
+
|
| 146 |
+
if not gts:
|
| 147 |
+
return 0.0
|
| 148 |
+
|
| 149 |
+
scorer = Cider()
|
| 150 |
+
score, _ = scorer.compute_score(gts, res)
|
| 151 |
+
return score
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 155 |
+
# Full Sweep Runner
|
| 156 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 157 |
+
|
| 158 |
+
def run_parameter_sweep(model_name, model_objs, dataloader, device,
|
| 159 |
+
beam_sizes=None, length_penalties=None, max_tokens=None,
|
| 160 |
+
eval_batches=25):
|
| 161 |
+
"""
|
| 162 |
+
Run the full decoding parameter sweep for one model.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
model_name : 'blip' | 'vit_gpt2' | 'git'
|
| 166 |
+
model_objs : dict of model + processor/tokenizer references
|
| 167 |
+
dataloader : validation DataLoader
|
| 168 |
+
device : torch.device
|
| 169 |
+
beam_sizes : list of int beam sizes (default: [3, 5, 10])
|
| 170 |
+
length_penalties : list of float penalties (default: [0.8, 1.0, 1.2])
|
| 171 |
+
max_tokens : list of int max new tokens (default: [20, 50])
|
| 172 |
+
eval_batches : number of batches per configuration
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
results: list of dicts with keys:
|
| 176 |
+
model, beam_size, length_penalty, max_tokens, cider
|
| 177 |
+
"""
|
| 178 |
+
beam_sizes = beam_sizes or BEAM_SIZES
|
| 179 |
+
length_penalties = length_penalties or LENGTH_PENALTIES
|
| 180 |
+
max_tokens = max_tokens or MAX_TOKENS
|
| 181 |
+
|
| 182 |
+
combos = list(itertools.product(beam_sizes, length_penalties, max_tokens))
|
| 183 |
+
print(f"\n🔬 Parameter Sweep — {model_name.upper()} ({len(combos)} configurations)")
|
| 184 |
+
print("=" * 70)
|
| 185 |
+
|
| 186 |
+
results = []
|
| 187 |
+
for num_beams, lp, mt in combos:
|
| 188 |
+
score = eval_one_config(
|
| 189 |
+
model_name, model_objs, dataloader, device,
|
| 190 |
+
num_beams=num_beams, max_new_tokens=mt,
|
| 191 |
+
length_penalty=lp, eval_batches=eval_batches,
|
| 192 |
+
)
|
| 193 |
+
results.append({
|
| 194 |
+
"model": model_name, "beam_size": num_beams,
|
| 195 |
+
"length_penalty": lp, "max_tokens": mt, "cider": score,
|
| 196 |
+
})
|
| 197 |
+
|
| 198 |
+
# ── Print summary table ───────────────────────────────────────────────────
|
| 199 |
+
print(f"\n{'='*70}")
|
| 200 |
+
print(f" Parameter Sweep Results — {model_name.upper()}")
|
| 201 |
+
print(f"{'='*70}")
|
| 202 |
+
print(f" {'Beams':>5} {'LenPenalty':>10} {'MaxTok':>7} {'CIDEr':>8}")
|
| 203 |
+
print(f" {'-'*5} {'-'*10} {'-'*7} {'-'*8}")
|
| 204 |
+
best = max(results, key=lambda r: r["cider"])
|
| 205 |
+
for r in sorted(results, key=lambda x: (-x["cider"], x["beam_size"])):
|
| 206 |
+
marker = " ← best" if r == best else ""
|
| 207 |
+
print(f" {r['beam_size']:>5} {r['length_penalty']:>10.1f} "
|
| 208 |
+
f"{r['max_tokens']:>7} {r['cider']:>8.4f}{marker}")
|
| 209 |
+
print(f"{'='*70}")
|
| 210 |
+
|
| 211 |
+
return results
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 215 |
+
# CLI Entrypoint
|
| 216 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 217 |
+
|
| 218 |
+
def main():
|
| 219 |
+
parser = argparse.ArgumentParser(description="Decoding parameter sweep")
|
| 220 |
+
parser.add_argument("--model", choices=["blip", "vit_gpt2", "git"],
|
| 221 |
+
default="blip")
|
| 222 |
+
parser.add_argument("--eval_batches", type=int, default=15)
|
| 223 |
+
args = parser.parse_args()
|
| 224 |
+
|
| 225 |
+
import sys, os
|
| 226 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 227 |
+
|
| 228 |
+
from config import CFG
|
| 229 |
+
from data_prep import get_dataloaders, get_dataloaders_for_model
|
| 230 |
+
|
| 231 |
+
device = torch.device(
|
| 232 |
+
"mps" if torch.backends.mps.is_available() else
|
| 233 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 234 |
+
)
|
| 235 |
+
cfg = CFG.load_for_model(args.model)
|
| 236 |
+
|
| 237 |
+
if args.model == "blip":
|
| 238 |
+
from models.blip_tuner import get_blip_model
|
| 239 |
+
model, processor = get_blip_model(cfg, device)
|
| 240 |
+
model.eval()
|
| 241 |
+
_, val_loader = get_dataloaders(cfg, processor)
|
| 242 |
+
model_objs = {"model": model, "processor": processor}
|
| 243 |
+
|
| 244 |
+
elif args.model == "vit_gpt2":
|
| 245 |
+
from models.vit_gpt2_tuner import get_vit_gpt2_model
|
| 246 |
+
model, processor, tokenizer = get_vit_gpt2_model(cfg, device)
|
| 247 |
+
model.eval()
|
| 248 |
+
_, val_loader = get_dataloaders_for_model(cfg, "vit_gpt2", processor, tokenizer)
|
| 249 |
+
model_objs = {"model": model, "tokenizer": tokenizer,
|
| 250 |
+
"pad_token_id": tokenizer.pad_token_id}
|
| 251 |
+
|
| 252 |
+
elif args.model == "git":
|
| 253 |
+
from models.git_tuner import get_git_model
|
| 254 |
+
model, processor = get_git_model(cfg, device)
|
| 255 |
+
model.eval()
|
| 256 |
+
_, val_loader = get_dataloaders_for_model(cfg, "git", processor)
|
| 257 |
+
model_objs = {"model": model, "processor": processor}
|
| 258 |
+
|
| 259 |
+
run_parameter_sweep(
|
| 260 |
+
args.model, model_objs, val_loader, device,
|
| 261 |
+
eval_batches=args.eval_batches,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if __name__ == "__main__":
|
| 266 |
+
main()
|
experiments/results_beam_search_and_decoding_settings_comparison.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Parameter Sweep Results — BLIP
|
| 2 |
+
## Best Configuration
|
| 3 |
+
- **Beams**: 10
|
| 4 |
+
- **Length Penalty**: 1.2
|
| 5 |
+
- **Max Tokens**: 50
|
| 6 |
+
- **CIDEr**: 0.6199
|
| 7 |
+
|
| 8 |
+
## Full Results Table
|
| 9 |
+
| Beams | LenPenalty | MaxTok | CIDEr |
|
| 10 |
+
|-------|------------|--------|--------|
|
| 11 |
+
| 10 | 1.2 | 50 | 0.6199 ← best |
|
| 12 |
+
| 10 | 1.0 | 20 | 0.5904 |
|
| 13 |
+
| 5 | 1.0 | 20 | 0.5896 |
|
| 14 |
+
| 10 | 1.2 | 20 | 0.5785 |
|
| 15 |
+
| 10 | 0.8 | 50 | 0.5722 |
|
| 16 |
+
| 3 | 1.2 | 20 | 0.5653 |
|
| 17 |
+
| 5 | 1.0 | 50 | 0.5598 |
|
| 18 |
+
| 5 | 1.2 | 20 | 0.5533 |
|
| 19 |
+
| 10 | 1.0 | 50 | 0.5457 |
|
| 20 |
+
| 3 | 1.2 | 50 | 0.5456 |
|
| 21 |
+
| 3 | 1.0 | 20 | 0.5451 |
|
| 22 |
+
| 10 | 0.8 | 20 | 0.5321 |
|
| 23 |
+
| 3 | 1.0 | 50 | 0.5262 |
|
| 24 |
+
| 5 | 1.2 | 50 | 0.5106 |
|
| 25 |
+
| 5 | 0.8 | 20 | 0.5046 |
|
| 26 |
+
| 3 | 0.8 | 50 | 0.5031 |
|
| 27 |
+
| 5 | 0.8 | 50 | 0.4914 |
|
| 28 |
+
| 3 | 0.8 | 20 | 0.4783 |
|
experiments/results_caption_filtering_strategy_comparison.md
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
✅ Image size set to 224px
|
| 2 |
+
✅ Gradient checkpointing enabled (BLIP)
|
| 3 |
+
✅ BLIP loaded on mps: Salesforce/blip-image-captioning-base (224.0M params)
|
| 4 |
+
|
| 5 |
+
📊 Data Preparation Analysis
|
| 6 |
+
============================================================
|
| 7 |
+
|
| 8 |
+
📈 Caption Word-Count Distribution (val set sample):
|
| 9 |
+
Count : 1000
|
| 10 |
+
Mean : 10.4 words
|
| 11 |
+
Range : 7 – 28 words
|
| 12 |
+
P10/P50/P90: 8 / 10 / 13
|
| 13 |
+
% Short (<5 words) : 0.0%
|
| 14 |
+
% Long (>25 words): 0.2%
|
| 15 |
+
|
| 16 |
+
Running strategy: 'raw'...
|
| 17 |
+
✅ CIDEr [raw]: 0.6359
|
| 18 |
+
|
| 19 |
+
Running strategy: 'short'...
|
| 20 |
+
✅ CIDEr [short]: 0.6016
|
| 21 |
+
|
| 22 |
+
Running strategy: 'long'...
|
| 23 |
+
✅ CIDEr [long]: 0.5389
|
| 24 |
+
|
| 25 |
+
Running strategy: 'filtered'...
|
| 26 |
+
✅ CIDEr [filtered]: 0.5877
|
| 27 |
+
|
| 28 |
+
============================================================
|
| 29 |
+
Data Preparation — CIDEr Comparison
|
| 30 |
+
============================================================
|
| 31 |
+
Strategy CIDEr Δ Raw Notes
|
| 32 |
+
--------------------------------------------------------
|
| 33 |
+
raw 0.6359 + 0.0000 Baseline — no filtering
|
| 34 |
+
short 0.6016 -0.0342 Short captions ≤ 9 words
|
| 35 |
+
long 0.5389 -0.0970 Long captions ≥ 12 words
|
| 36 |
+
filtered 0.5877 -0.0481 Quality filter 5-25 words ← recommended
|
| 37 |
+
============================================================
|
| 38 |
+
|
| 39 |
+
💡 Key Insight:
|
| 40 |
+
Raw captions perform comparably — dataset is already clean.
|
| 41 |
+
Recommendation: use 'filtered' strategy (5-25 words) for
|
| 42 |
+
reproducible, balanced training across all models.
|
| 43 |
+
|
experiments/results_cross_attention_masking_impact_on_caption_quality.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
✅ Image size set to 224px
|
| 2 |
+
✅ Gradient checkpointing enabled (BLIP)
|
| 3 |
+
✅ BLIP loaded on mps: Salesforce/blip-image-captioning-base (224.0M params)
|
| 4 |
+
|
| 5 |
+
============================================================
|
| 6 |
+
Ablation Mode : BASELINE
|
| 7 |
+
Beams=4 MaxTokens=32 LenPenalty=1.0
|
| 8 |
+
============================================================
|
| 9 |
+
✅ CIDEr [baseline]: 0.5371
|
| 10 |
+
|
| 11 |
+
============================================================
|
| 12 |
+
Ablation Mode : RANDOM_DROPOUT
|
| 13 |
+
Beams=4 MaxTokens=32 LenPenalty=1.0
|
| 14 |
+
============================================================
|
| 15 |
+
✅ CIDEr [random_dropout]: 0.5371
|
| 16 |
+
|
| 17 |
+
============================================================
|
| 18 |
+
Ablation Mode : CENTER_FOCUS
|
| 19 |
+
Beams=4 MaxTokens=32 LenPenalty=1.0
|
| 20 |
+
============================================================
|
| 21 |
+
✅ CIDEr [center_focus]: 0.5371
|
| 22 |
+
|
| 23 |
+
============================================================
|
| 24 |
+
Ablation Mode : SQUINT
|
| 25 |
+
Beams=4 MaxTokens=32 LenPenalty=1.0
|
| 26 |
+
============================================================
|
| 27 |
+
✅ CIDEr [squint]: 0.0008
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
============================================================
|
| 31 |
+
Cross-Attention Ablation Results (CIDEr)
|
| 32 |
+
Beams=4 MaxTokens=32 LenPenalty=1.0
|
| 33 |
+
============================================================
|
| 34 |
+
Mode CIDEr Δ Baseline
|
| 35 |
+
------------------------------------------------------------
|
| 36 |
+
baseline 0.5371 + 0.0000
|
| 37 |
+
random_dropout 0.5371 + 0.0000
|
| 38 |
+
center_focus 0.5371 + 0.0000
|
| 39 |
+
squint 0.0008 -0.5363
|
| 40 |
+
============================================================
|
| 41 |
+
============================================================
|
experiments/results_parameter_sweep.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Parameter Sweep Results — BLIP
|
| 2 |
+
## Best Configuration
|
| 3 |
+
- **Beams**: 10
|
| 4 |
+
- **Length Penalty**: 1.2
|
| 5 |
+
- **Max Tokens**: 50
|
| 6 |
+
- **CIDEr**: 0.6199
|
| 7 |
+
|
| 8 |
+
## Full Results Table
|
| 9 |
+
| Beams | LenPenalty | MaxTok | CIDEr |
|
| 10 |
+
|-------|------------|--------|--------|
|
| 11 |
+
| 10 | 1.2 | 50 | 0.6199 ← best |
|
| 12 |
+
| 10 | 1.0 | 20 | 0.5904 |
|
| 13 |
+
| 5 | 1.0 | 20 | 0.5896 |
|
| 14 |
+
| 10 | 1.2 | 20 | 0.5785 |
|
| 15 |
+
| 10 | 0.8 | 50 | 0.5722 |
|
| 16 |
+
| 3 | 1.2 | 20 | 0.5653 |
|
| 17 |
+
| 5 | 1.0 | 50 | 0.5598 |
|
| 18 |
+
| 5 | 1.2 | 20 | 0.5533 |
|
| 19 |
+
| 10 | 1.0 | 50 | 0.5457 |
|
| 20 |
+
| 3 | 1.2 | 50 | 0.5456 |
|
| 21 |
+
| 3 | 1.0 | 20 | 0.5451 |
|
| 22 |
+
| 10 | 0.8 | 20 | 0.5321 |
|
| 23 |
+
| 3 | 1.0 | 50 | 0.5262 |
|
| 24 |
+
| 5 | 1.2 | 50 | 0.5106 |
|
| 25 |
+
| 5 | 0.8 | 20 | 0.5046 |
|
| 26 |
+
| 3 | 0.8 | 50 | 0.5031 |
|
| 27 |
+
| 5 | 0.8 | 50 | 0.4914 |
|
| 28 |
+
| 3 | 0.8 | 20 | 0.4783 |
|
input.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
iter_01.ipynb
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "5e83734d",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"\n",
|
| 14 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m26.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
|
| 15 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
| 16 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
| 17 |
+
]
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"source": [
|
| 21 |
+
"\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"%pip install -q \"datasets<4.0.0\" transformers accelerate pillow tqdm numpy torch torchvision\n"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": 1,
|
| 29 |
+
"id": "1f26db57",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [
|
| 32 |
+
{
|
| 33 |
+
"name": "stderr",
|
| 34 |
+
"output_type": "stream",
|
| 35 |
+
"text": [
|
| 36 |
+
"/Users/makumar/Documents/.venv/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 37 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"name": "stdout",
|
| 42 |
+
"output_type": "stream",
|
| 43 |
+
"text": [
|
| 44 |
+
"✅ Config loaded\n"
|
| 45 |
+
]
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"source": [
|
| 49 |
+
"import os, math, time, random\n",
|
| 50 |
+
"from dataclasses import dataclass\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"import numpy as np\n",
|
| 53 |
+
"import torch\n",
|
| 54 |
+
"from torch.utils.data import DataLoader\n",
|
| 55 |
+
"from torch.optim import AdamW # use PyTorch AdamW, not transformers [web:34][web:36]\n",
|
| 56 |
+
"from tqdm.auto import tqdm\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"from datasets import load_dataset\n",
|
| 59 |
+
"from transformers import (\n",
|
| 60 |
+
" BlipProcessor,\n",
|
| 61 |
+
" BlipForConditionalGeneration,\n",
|
| 62 |
+
" get_cosine_schedule_with_warmup, # still valid in transformers optimization APIs [web:41][web:46]\n",
|
| 63 |
+
")\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"@dataclass\n",
|
| 68 |
+
"class CFG:\n",
|
| 69 |
+
" model_id: str = \"Salesforce/blip-image-captioning-base\"\n",
|
| 70 |
+
" dataset_id: str = \"whyen-wang/coco_captions\" # COCO captions dataset: image + list of 5 captions [web:7]\n",
|
| 71 |
+
"\n",
|
| 72 |
+
" train_samples: int = 1000 # start small; increase to 10k–50k later\n",
|
| 73 |
+
" val_samples: int = 200\n",
|
| 74 |
+
" seed: int = 42\n",
|
| 75 |
+
"\n",
|
| 76 |
+
" image_size: int = 224\n",
|
| 77 |
+
" max_target_len: int = 32\n",
|
| 78 |
+
"\n",
|
| 79 |
+
" batch_size: int = 4\n",
|
| 80 |
+
" grad_accum: int = 8\n",
|
| 81 |
+
" epochs: int = 1\n",
|
| 82 |
+
"\n",
|
| 83 |
+
" lr: float = 1e-5\n",
|
| 84 |
+
" weight_decay: float = 0.01\n",
|
| 85 |
+
" warmup_ratio: float = 0.03\n",
|
| 86 |
+
" max_grad_norm: float = 1.0\n",
|
| 87 |
+
"\n",
|
| 88 |
+
" num_workers: int = 0 # safer on macOS\n",
|
| 89 |
+
" log_every: int = 10\n",
|
| 90 |
+
" save_every_steps: int = 100\n",
|
| 91 |
+
"\n",
|
| 92 |
+
" out_dir: str = \"./blip_coco_ft_mps\"\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"cfg = CFG()\n",
|
| 95 |
+
"os.makedirs(cfg.out_dir, exist_ok=True)\n",
|
| 96 |
+
"print(\"✅ Config loaded\")\n"
|
| 97 |
+
]
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"cell_type": "code",
|
| 101 |
+
"execution_count": 2,
|
| 102 |
+
"id": "74fa92b3",
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [
|
| 105 |
+
{
|
| 106 |
+
"name": "stdout",
|
| 107 |
+
"output_type": "stream",
|
| 108 |
+
"text": [
|
| 109 |
+
"✅ Device: mps\n"
|
| 110 |
+
]
|
| 111 |
+
}
|
| 112 |
+
],
|
| 113 |
+
"source": [
|
| 114 |
+
"def seed_all(seed: int):\n",
|
| 115 |
+
" random.seed(seed)\n",
|
| 116 |
+
" np.random.seed(seed)\n",
|
| 117 |
+
" torch.manual_seed(seed)\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"seed_all(cfg.seed)\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"if torch.backends.mps.is_available():\n",
|
| 122 |
+
" device = torch.device(\"mps\")\n",
|
| 123 |
+
"elif torch.cuda.is_available():\n",
|
| 124 |
+
" device = torch.device(\"cuda\")\n",
|
| 125 |
+
"else:\n",
|
| 126 |
+
" device = torch.device(\"cpu\")\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"print(f\"✅ Device: {device}\")\n"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "code",
|
| 133 |
+
"execution_count": 3,
|
| 134 |
+
"id": "46dced20",
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"outputs": [
|
| 137 |
+
{
|
| 138 |
+
"name": "stderr",
|
| 139 |
+
"output_type": "stream",
|
| 140 |
+
"text": [
|
| 141 |
+
"Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n",
|
| 142 |
+
"Downloading data: 100%|██████████| 19.3G/19.3G [28:19<00:00, 11.4MB/s] \n",
|
| 143 |
+
"Downloading data: 100%|██████████| 816M/816M [01:08<00:00, 12.0MB/s] \n",
|
| 144 |
+
"Generating train split: 118287 examples [00:02, 54322.81 examples/s]\n",
|
| 145 |
+
"Generating validation split: 5000 examples [00:00, 55846.76 examples/s]\n"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"name": "stdout",
|
| 150 |
+
"output_type": "stream",
|
| 151 |
+
"text": [
|
| 152 |
+
"DatasetDict({\n",
|
| 153 |
+
" train: Dataset({\n",
|
| 154 |
+
" features: ['image', 'captions'],\n",
|
| 155 |
+
" num_rows: 118287\n",
|
| 156 |
+
" })\n",
|
| 157 |
+
" validation: Dataset({\n",
|
| 158 |
+
" features: ['image', 'captions'],\n",
|
| 159 |
+
" num_rows: 5000\n",
|
| 160 |
+
" })\n",
|
| 161 |
+
"})\n",
|
| 162 |
+
"Example keys: dict_keys(['image', 'captions'])\n",
|
| 163 |
+
"Captions per image: 5\n",
|
| 164 |
+
"✅ Train: 1000, Val: 200\n"
|
| 165 |
+
]
|
| 166 |
+
}
|
| 167 |
+
],
|
| 168 |
+
"source": [
|
| 169 |
+
"import aiohttp\n",
|
| 170 |
+
"import datasets\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"# Use storage_options to increase the timeout from 5 minutes (300s) to 1 hour (3600s)\n",
|
| 173 |
+
"ds = load_dataset(\n",
|
| 174 |
+
" cfg.dataset_id, \n",
|
| 175 |
+
" trust_remote_code=True,\n",
|
| 176 |
+
" storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}}\n",
|
| 177 |
+
")\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"print(ds)\n",
|
| 180 |
+
"print(\"Example keys:\", ds[\"train\"][0].keys())\n",
|
| 181 |
+
"print(\"Captions per image:\", len(ds[\"train\"][0][\"captions\"]))\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"train_split = \"train\"\n",
|
| 184 |
+
"val_split = \"validation\" if \"validation\" in ds else (\"val\" if \"val\" in ds else \"train\")\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"train_ds = ds[train_split].shuffle(seed=cfg.seed).select(\n",
|
| 187 |
+
" range(min(cfg.train_samples, len(ds[train_split])))\n",
|
| 188 |
+
")\n",
|
| 189 |
+
"val_ds = ds[val_split].shuffle(seed=cfg.seed + 1).select(\n",
|
| 190 |
+
" range(min(cfg.val_samples, len(ds[val_split])))\n",
|
| 191 |
+
")\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"print(f\"✅ Train: {len(train_ds)}, Val: {len(val_ds)}\")\n"
|
| 194 |
+
]
|
| 195 |
+
},
|
| 196 |
+
{
|
| 197 |
+
"cell_type": "code",
|
| 198 |
+
"execution_count": 4,
|
| 199 |
+
"id": "681b5a5f",
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"outputs": [
|
| 202 |
+
{
|
| 203 |
+
"name": "stderr",
|
| 204 |
+
"output_type": "stream",
|
| 205 |
+
"text": [
|
| 206 |
+
"The image processor of type `BlipImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. \n",
|
| 207 |
+
"Loading weights: 100%|██████████| 473/473 [00:00<00:00, 1923.98it/s, Materializing param=vision_model.post_layernorm.weight] \n",
|
| 208 |
+
"The tied weights mapping and config for this model specifies to tie text_decoder.cls.predictions.bias to text_decoder.cls.predictions.decoder.bias, but both are present in the checkpoints, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning\n",
|
| 209 |
+
"The tied weights mapping and config for this model specifies to tie text_decoder.bert.embeddings.word_embeddings.weight to text_decoder.cls.predictions.decoder.weight, but both are present in the checkpoints, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning\n",
|
| 210 |
+
"\u001b[1mBlipForConditionalGeneration LOAD REPORT\u001b[0m from: Salesforce/blip-image-captioning-base\n",
|
| 211 |
+
"Key | Status | | \n",
|
| 212 |
+
"------------------------------------------+------------+--+-\n",
|
| 213 |
+
"text_decoder.bert.embeddings.position_ids | UNEXPECTED | | \n",
|
| 214 |
+
"\n",
|
| 215 |
+
"\u001b[3mNotes:\n",
|
| 216 |
+
"- UNEXPECTED\u001b[3m\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch.\u001b[0m\n"
|
| 217 |
+
]
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"name": "stdout",
|
| 221 |
+
"output_type": "stream",
|
| 222 |
+
"text": [
|
| 223 |
+
"✅ Gradient checkpointing enabled\n",
|
| 224 |
+
"✅ Model loaded: Salesforce/blip-image-captioning-base\n"
|
| 225 |
+
]
|
| 226 |
+
}
|
| 227 |
+
],
|
| 228 |
+
"source": [
|
| 229 |
+
"processor = BlipProcessor.from_pretrained(cfg.model_id)\n",
|
| 230 |
+
"model = BlipForConditionalGeneration.from_pretrained(cfg.model_id)\n",
|
| 231 |
+
"\n",
|
| 232 |
+
"# Force 224px images (lighter for Mac)\n",
|
| 233 |
+
"try:\n",
|
| 234 |
+
" processor.image_processor.size = {\"height\": cfg.image_size, \"width\": cfg.image_size}\n",
|
| 235 |
+
"except Exception as e:\n",
|
| 236 |
+
" print(f\"⚠️ Could not set image size: {e}\")\n",
|
| 237 |
+
"\n",
|
| 238 |
+
"# Memory helpers\n",
|
| 239 |
+
"try:\n",
|
| 240 |
+
" model.gradient_checkpointing_enable()\n",
|
| 241 |
+
" print(\"✅ Gradient checkpointing enabled\")\n",
|
| 242 |
+
"except Exception as e:\n",
|
| 243 |
+
" print(f\"⚠️ Gradient checkpointing failed: {e}\")\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"model.config.use_cache = False # must be False when using gradient checkpointing\n",
|
| 246 |
+
"model.to(device)\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"print(f\"✅ Model loaded: {cfg.model_id}\")\n"
|
| 249 |
+
]
|
| 250 |
+
},
|
| 251 |
+
{
|
| 252 |
+
"cell_type": "code",
|
| 253 |
+
"execution_count": 9,
|
| 254 |
+
"id": "ae518a72",
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"outputs": [],
|
| 257 |
+
"source": [
|
| 258 |
+
"def collate_fn(examples):\n",
|
| 259 |
+
" images = [ex[\"image\"].convert(\"RGB\") for ex in examples]\n",
|
| 260 |
+
" # pick one random caption per image\n",
|
| 261 |
+
" captions = [random.choice(ex[\"captions\"]) for ex in examples]\n",
|
| 262 |
+
"\n",
|
| 263 |
+
" encoding = processor(\n",
|
| 264 |
+
" images=images,\n",
|
| 265 |
+
" text=captions,\n",
|
| 266 |
+
" padding=\"max_length\",\n",
|
| 267 |
+
" truncation=True,\n",
|
| 268 |
+
" max_length=cfg.max_target_len,\n",
|
| 269 |
+
" return_tensors=\"pt\",\n",
|
| 270 |
+
" )\n",
|
| 271 |
+
"\n",
|
| 272 |
+
" # BLIP needs `labels` = `input_ids` for captioning loss\n",
|
| 273 |
+
" encoding[\"labels\"] = encoding[\"input_ids\"].clone()\n",
|
| 274 |
+
"\n",
|
| 275 |
+
" return encoding\n",
|
| 276 |
+
"\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"train_loader = DataLoader(\n",
|
| 279 |
+
" train_ds,\n",
|
| 280 |
+
" batch_size=cfg.batch_size,\n",
|
| 281 |
+
" shuffle=True,\n",
|
| 282 |
+
" num_workers=cfg.num_workers,\n",
|
| 283 |
+
" collate_fn=collate_fn,\n",
|
| 284 |
+
" pin_memory=True,\n",
|
| 285 |
+
")\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"val_loader = DataLoader(\n",
|
| 288 |
+
" val_ds,\n",
|
| 289 |
+
" batch_size=cfg.batch_size,\n",
|
| 290 |
+
" shuffle=False,\n",
|
| 291 |
+
" num_workers=cfg.num_workers,\n",
|
| 292 |
+
" collate_fn=collate_fn,\n",
|
| 293 |
+
" pin_memory=True,\n",
|
| 294 |
+
")"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"cell_type": "code",
|
| 299 |
+
"execution_count": 10,
|
| 300 |
+
"id": "becf6f22",
|
| 301 |
+
"metadata": {},
|
| 302 |
+
"outputs": [
|
| 303 |
+
{
|
| 304 |
+
"name": "stdout",
|
| 305 |
+
"output_type": "stream",
|
| 306 |
+
"text": [
|
| 307 |
+
"✅ Update steps: 32, Warmup: 0\n"
|
| 308 |
+
]
|
| 309 |
+
}
|
| 310 |
+
],
|
| 311 |
+
"source": [
|
| 312 |
+
"optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"total_update_steps = math.ceil(len(train_loader) / cfg.grad_accum) * cfg.epochs\n",
|
| 315 |
+
"warmup_steps = int(total_update_steps * cfg.warmup_ratio)\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"scheduler = get_cosine_schedule_with_warmup(\n",
|
| 318 |
+
" optimizer,\n",
|
| 319 |
+
" num_warmup_steps=warmup_steps,\n",
|
| 320 |
+
" num_training_steps=total_update_steps,\n",
|
| 321 |
+
")\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"print(f\"✅ Update steps: {total_update_steps}, Warmup: {warmup_steps}\")\n"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"cell_type": "code",
|
| 328 |
+
"execution_count": 11,
|
| 329 |
+
"id": "4134441d",
|
| 330 |
+
"metadata": {},
|
| 331 |
+
"outputs": [
|
| 332 |
+
{
|
| 333 |
+
"name": "stdout",
|
| 334 |
+
"output_type": "stream",
|
| 335 |
+
"text": [
|
| 336 |
+
"✅ Checkpoint helpers ready\n"
|
| 337 |
+
]
|
| 338 |
+
}
|
| 339 |
+
],
|
| 340 |
+
"source": [
|
| 341 |
+
"def save_ckpt(step, epoch):\n",
|
| 342 |
+
" \"\"\"\n",
|
| 343 |
+
" Save model weights, processor, and training state to cfg.out_dir.\n",
|
| 344 |
+
" Directory: out_dir/ckpt_step{step}_epoch{epoch}\n",
|
| 345 |
+
" \"\"\"\n",
|
| 346 |
+
" path = os.path.join(cfg.out_dir, f\"ckpt_step{step}_epoch{epoch}\")\n",
|
| 347 |
+
" os.makedirs(path, exist_ok=True)\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" # Save model weights + config in HF format\n",
|
| 350 |
+
" model.save_pretrained(path)\n",
|
| 351 |
+
" processor.save_pretrained(path)\n",
|
| 352 |
+
"\n",
|
| 353 |
+
" # Save optimizer/scheduler state, step, epoch\n",
|
| 354 |
+
" torch.save(\n",
|
| 355 |
+
" {\n",
|
| 356 |
+
" \"step\": step,\n",
|
| 357 |
+
" \"epoch\": epoch,\n",
|
| 358 |
+
" \"optimizer\": optimizer.state_dict(),\n",
|
| 359 |
+
" \"scheduler\": scheduler.state_dict(),\n",
|
| 360 |
+
" \"cfg\": cfg.__dict__,\n",
|
| 361 |
+
" },\n",
|
| 362 |
+
" os.path.join(path, \"train_state.pt\"),\n",
|
| 363 |
+
" )\n",
|
| 364 |
+
"\n",
|
| 365 |
+
" print(f\"✅ Checkpoint saved: {path}\")\n",
|
| 366 |
+
"\n",
|
| 367 |
+
"\n",
|
| 368 |
+
"def load_ckpt(path):\n",
|
| 369 |
+
" \"\"\"\n",
|
| 370 |
+
" Load model + optimizer/scheduler from a checkpoint directory.\n",
|
| 371 |
+
" \"\"\"\n",
|
| 372 |
+
" # Load model weights\n",
|
| 373 |
+
" loaded_model = BlipForConditionalGeneration.from_pretrained(path)\n",
|
| 374 |
+
" model.load_state_dict(loaded_model.state_dict())\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" # Load training state\n",
|
| 377 |
+
" state = torch.load(os.path.join(path, \"train_state.pt\"), map_location=\"cpu\")\n",
|
| 378 |
+
" optimizer.load_state_dict(state[\"optimizer\"])\n",
|
| 379 |
+
" scheduler.load_state_dict(state[\"scheduler\"])\n",
|
| 380 |
+
"\n",
|
| 381 |
+
" print(f\"✅ Resumed from step {state['step']}, epoch {state['epoch']}\")\n",
|
| 382 |
+
" return state[\"step\"], state[\"epoch\"]\n",
|
| 383 |
+
"\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"print(\"✅ Checkpoint helpers ready\")\n"
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"cell_type": "code",
|
| 390 |
+
"execution_count": 12,
|
| 391 |
+
"id": "c323b9bb",
|
| 392 |
+
"metadata": {},
|
| 393 |
+
"outputs": [
|
| 394 |
+
{
|
| 395 |
+
"name": "stderr",
|
| 396 |
+
"output_type": "stream",
|
| 397 |
+
"text": [
|
| 398 |
+
"Epoch 1/1: 0%| | 0/250 [00:00<?, ?it/s]/Users/makumar/Documents/.venv/lib/python3.14/site-packages/torch/utils/data/dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, device pinned memory won't be used.\n",
|
| 399 |
+
" super().__init__(loader)\n",
|
| 400 |
+
"Epoch 1/1: 100%|██████████| 250/250 [01:05<00:00, 3.82it/s, loss=6.4825, lr=9.61e-08]\n",
|
| 401 |
+
"Writing model shards: 100%|██████████| 1/1 [00:00<00:00, 2.02it/s]\n"
|
| 402 |
+
]
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"name": "stdout",
|
| 406 |
+
"output_type": "stream",
|
| 407 |
+
"text": [
|
| 408 |
+
"✅ Checkpoint saved: ./blip_coco_ft_mps/ckpt_step31_epoch1\n",
|
| 409 |
+
"✅ Training complete in 1.15 minutes\n"
|
| 410 |
+
]
|
| 411 |
+
}
|
| 412 |
+
],
|
| 413 |
+
"source": [
|
| 414 |
+
"model.train()\n",
|
| 415 |
+
"\n",
|
| 416 |
+
"global_step = 0\n",
|
| 417 |
+
"t0 = time.time()\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"for epoch in range(1, cfg.epochs + 1):\n",
|
| 420 |
+
" pbar = tqdm(train_loader, desc=f\"Epoch {epoch}/{cfg.epochs}\")\n",
|
| 421 |
+
" running_loss = 0.0\n",
|
| 422 |
+
"\n",
|
| 423 |
+
" optimizer.zero_grad(set_to_none=True)\n",
|
| 424 |
+
"\n",
|
| 425 |
+
" for i, batch in enumerate(pbar, start=1):\n",
|
| 426 |
+
" batch = {k: v.to(device) for k, v in batch.items()}\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" out = model(**batch) # model returns loss when labels are passed [web:17]\n",
|
| 429 |
+
" loss = out.loss / cfg.grad_accum\n",
|
| 430 |
+
" loss.backward()\n",
|
| 431 |
+
"\n",
|
| 432 |
+
" running_loss += loss.item()\n",
|
| 433 |
+
"\n",
|
| 434 |
+
" if i % cfg.grad_accum == 0:\n",
|
| 435 |
+
" torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)\n",
|
| 436 |
+
" optimizer.step()\n",
|
| 437 |
+
" scheduler.step()\n",
|
| 438 |
+
" optimizer.zero_grad(set_to_none=True)\n",
|
| 439 |
+
"\n",
|
| 440 |
+
" global_step += 1\n",
|
| 441 |
+
"\n",
|
| 442 |
+
" if global_step % cfg.log_every == 0:\n",
|
| 443 |
+
" avg_loss = running_loss / cfg.log_every\n",
|
| 444 |
+
" running_loss = 0.0\n",
|
| 445 |
+
" pbar.set_postfix({\n",
|
| 446 |
+
" \"loss\": f\"{avg_loss:.4f}\",\n",
|
| 447 |
+
" \"lr\": f\"{scheduler.get_last_lr()[0]:.2e}\",\n",
|
| 448 |
+
" })\n",
|
| 449 |
+
"\n",
|
| 450 |
+
" if global_step % cfg.save_every_steps == 0:\n",
|
| 451 |
+
" save_ckpt(global_step, epoch)\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" # Save checkpoint at end of epoch\n",
|
| 454 |
+
" save_ckpt(global_step, epoch)\n",
|
| 455 |
+
"\n",
|
| 456 |
+
"elapsed = (time.time() - t0) / 60.0\n",
|
| 457 |
+
"print(f\"✅ Training complete in {elapsed:.2f} minutes\")\n"
|
| 458 |
+
]
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"cell_type": "code",
|
| 462 |
+
"execution_count": 13,
|
| 463 |
+
"id": "f83558b0",
|
| 464 |
+
"metadata": {},
|
| 465 |
+
"outputs": [
|
| 466 |
+
{
|
| 467 |
+
"name": "stdout",
|
| 468 |
+
"output_type": "stream",
|
| 469 |
+
"text": [
|
| 470 |
+
"Sample predictions:\n",
|
| 471 |
+
"\n",
|
| 472 |
+
"GT: A group of people kneeling down beside some sheep.\n",
|
| 473 |
+
"Pred: a group of people standing around a dog on a leash\n",
|
| 474 |
+
"--------------------------------------------------------------------------------\n",
|
| 475 |
+
"GT: Two skiers prepare to make their way past an embankment\n",
|
| 476 |
+
"Pred: a group of people riding horses through a snow covered field\n",
|
| 477 |
+
"--------------------------------------------------------------------------------\n",
|
| 478 |
+
"GT: A person on skis skiing down a mountain.\n",
|
| 479 |
+
"Pred: a person skiing down a snow covered slope\n",
|
| 480 |
+
"--------------------------------------------------------------------------------\n",
|
| 481 |
+
"✅ Inference test complete\n"
|
| 482 |
+
]
|
| 483 |
+
}
|
| 484 |
+
],
|
| 485 |
+
"source": [
|
| 486 |
+
"model.eval()\n",
|
| 487 |
+
"\n",
|
| 488 |
+
"@torch.no_grad()\n",
|
| 489 |
+
"def generate_caption(pil_image, max_new_tokens=30, num_beams=3):\n",
|
| 490 |
+
" inputs = processor(images=pil_image.convert(\"RGB\"), return_tensors=\"pt\")\n",
|
| 491 |
+
" inputs = {k: v.to(device) for k, v in inputs.items()}\n",
|
| 492 |
+
" ids = model.generate(\n",
|
| 493 |
+
" **inputs,\n",
|
| 494 |
+
" max_new_tokens=max_new_tokens,\n",
|
| 495 |
+
" num_beams=num_beams,\n",
|
| 496 |
+
" )\n",
|
| 497 |
+
" return processor.decode(ids[0], skip_special_tokens=True)\n",
|
| 498 |
+
"\n",
|
| 499 |
+
"print(\"Sample predictions:\\n\")\n",
|
| 500 |
+
"for idx in [0, 1, 2]:\n",
|
| 501 |
+
" ex = val_ds[idx]\n",
|
| 502 |
+
" gt = ex[\"captions\"][0]\n",
|
| 503 |
+
" pred = generate_caption(ex[\"image\"])\n",
|
| 504 |
+
" print(f\"GT: {gt}\")\n",
|
| 505 |
+
" print(f\"Pred: {pred}\")\n",
|
| 506 |
+
" print(\"-\" * 80)\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"model.train()\n",
|
| 509 |
+
"print(\"✅ Inference test complete\")\n"
|
| 510 |
+
]
|
| 511 |
+
},
|
| 512 |
+
{
|
| 513 |
+
"cell_type": "code",
|
| 514 |
+
"execution_count": null,
|
| 515 |
+
"id": "c246206b",
|
| 516 |
+
"metadata": {},
|
| 517 |
+
"outputs": [],
|
| 518 |
+
"source": []
|
| 519 |
+
}
|
| 520 |
+
],
|
| 521 |
+
"metadata": {
|
| 522 |
+
"kernelspec": {
|
| 523 |
+
"display_name": ".venv",
|
| 524 |
+
"language": "python",
|
| 525 |
+
"name": "python3"
|
| 526 |
+
},
|
| 527 |
+
"language_info": {
|
| 528 |
+
"codemirror_mode": {
|
| 529 |
+
"name": "ipython",
|
| 530 |
+
"version": 3
|
| 531 |
+
},
|
| 532 |
+
"file_extension": ".py",
|
| 533 |
+
"mimetype": "text/x-python",
|
| 534 |
+
"name": "python",
|
| 535 |
+
"nbconvert_exporter": "python",
|
| 536 |
+
"pygments_lexer": "ipython3",
|
| 537 |
+
"version": "3.14.2"
|
| 538 |
+
}
|
| 539 |
+
},
|
| 540 |
+
"nbformat": 4,
|
| 541 |
+
"nbformat_minor": 5
|
| 542 |
+
}
|
models/blip_tuner.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/blip_tuner.py
|
| 3 |
+
====================
|
| 4 |
+
Baseline 3 — Multimodal Mixture Attention (BLIP)
|
| 5 |
+
|
| 6 |
+
Architecture: BLIP's MED (Multimodal Encoder-Decoder) architecture injects
|
| 7 |
+
specialized gated cross-attention between self-attention and feed-forward layers.
|
| 8 |
+
The visual encoder output (image patch embeddings) is queried by the text decoder
|
| 9 |
+
via cross-attention that is applied carefully at each decoder layer.
|
| 10 |
+
|
| 11 |
+
This module also provides `generate_with_mask()` for inference-time ablation
|
| 12 |
+
experiments that manipulate the encoder_attention_mask to test spatial restrictions.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import torch
|
| 17 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_blip_model(cfg, device):
|
| 21 |
+
"""
|
| 22 |
+
Loads BLIP model and processor with MPS and memory optimizations.
|
| 23 |
+
"""
|
| 24 |
+
processor = BlipProcessor.from_pretrained(cfg.model_id, use_fast=True)
|
| 25 |
+
model = BlipForConditionalGeneration.from_pretrained(cfg.model_id)
|
| 26 |
+
|
| 27 |
+
# Force 224px images for efficiency (especially on Mac/MPS)
|
| 28 |
+
try:
|
| 29 |
+
processor.image_processor.size = {"height": cfg.image_size, "width": cfg.image_size}
|
| 30 |
+
print(f"✅ Image size set to {cfg.image_size}px")
|
| 31 |
+
except Exception as e:
|
| 32 |
+
print(f"⚠️ Could not set image size: {e}")
|
| 33 |
+
|
| 34 |
+
# Gradient checkpointing for VRAM efficiency
|
| 35 |
+
try:
|
| 36 |
+
model.gradient_checkpointing_enable()
|
| 37 |
+
print("✅ Gradient checkpointing enabled (BLIP)")
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"⚠️ Gradient checkpointing failed: {e}")
|
| 40 |
+
|
| 41 |
+
model.config.use_cache = False # Must be False with gradient checkpointing
|
| 42 |
+
model.to(device)
|
| 43 |
+
|
| 44 |
+
n_params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 45 |
+
print(f"✅ BLIP loaded on {device}: {cfg.model_id} ({n_params:.1f}M params)")
|
| 46 |
+
|
| 47 |
+
return model, processor
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def generate_with_mask(model, processor, image_pil=None, device=None,
|
| 51 |
+
pixel_values=None,
|
| 52 |
+
encoder_hidden_states=None,
|
| 53 |
+
encoder_attention_mask=None,
|
| 54 |
+
max_new_tokens=32, num_beams=4):
|
| 55 |
+
"""
|
| 56 |
+
Generate a caption for a single PIL image (or pre-computed tensors) with an ablation mask.
|
| 57 |
+
|
| 58 |
+
Ablation modes supported:
|
| 59 |
+
- Baseline: 197 patches visible
|
| 60 |
+
- Random Dropout: 50% spatial patches masked
|
| 61 |
+
- Center-Focus: Inner 8x8 patches visible
|
| 62 |
+
- Squint: Requires passing pre-pooled `encoder_hidden_states` of shape (B, 2, C).
|
| 63 |
+
"""
|
| 64 |
+
model.eval()
|
| 65 |
+
|
| 66 |
+
# 1. Get pixel values
|
| 67 |
+
if pixel_values is None and image_pil is not None:
|
| 68 |
+
inputs = processor(images=image_pil, return_tensors="pt").to(device)
|
| 69 |
+
pixel_values = inputs["pixel_values"]
|
| 70 |
+
|
| 71 |
+
batch_size = pixel_values.shape[0] if pixel_values is not None else encoder_hidden_states.shape[0]
|
| 72 |
+
dev = pixel_values.device if pixel_values is not None else encoder_hidden_states.device
|
| 73 |
+
|
| 74 |
+
# 2. Extract visual features if not pre-provided (e.g., Squint mode provides them)
|
| 75 |
+
if encoder_hidden_states is None:
|
| 76 |
+
vision_outputs = model.vision_model(pixel_values=pixel_values)
|
| 77 |
+
encoder_hidden_states = vision_outputs[0]
|
| 78 |
+
|
| 79 |
+
# 3. Handle encoder_attention_mask default (Baseline = all ones)
|
| 80 |
+
if encoder_attention_mask is None:
|
| 81 |
+
encoder_attention_mask = torch.ones(
|
| 82 |
+
encoder_hidden_states.size()[:-1], dtype=torch.long, device=dev
|
| 83 |
+
)
|
| 84 |
+
else:
|
| 85 |
+
encoder_attention_mask = encoder_attention_mask.to(dev)
|
| 86 |
+
|
| 87 |
+
# 4. Prepare decoder input IDs (BOS token)
|
| 88 |
+
input_ids = (
|
| 89 |
+
torch.LongTensor([[model.decoder_input_ids, model.config.text_config.eos_token_id]])
|
| 90 |
+
.repeat(batch_size, 1)
|
| 91 |
+
.to(dev)
|
| 92 |
+
)
|
| 93 |
+
input_ids[:, 0] = model.config.text_config.bos_token_id
|
| 94 |
+
|
| 95 |
+
# 5. Bypass the outer model.generate() to avoid hardcoded mask conflicts
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
output_ids = model.text_decoder.generate(
|
| 98 |
+
input_ids=input_ids[:, :-1],
|
| 99 |
+
eos_token_id=model.config.text_config.sep_token_id,
|
| 100 |
+
pad_token_id=model.config.text_config.pad_token_id,
|
| 101 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 102 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 103 |
+
max_new_tokens=max_new_tokens,
|
| 104 |
+
num_beams=num_beams,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
captions = processor.batch_decode(output_ids, skip_special_tokens=True)
|
| 108 |
+
return captions
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def save_ckpt(model, processor, optimizer, scheduler, step, epoch, cfg_dict, path):
|
| 112 |
+
"""
|
| 113 |
+
Save model weights, processor, and training state.
|
| 114 |
+
"""
|
| 115 |
+
os.makedirs(path, exist_ok=True)
|
| 116 |
+
model.save_pretrained(path)
|
| 117 |
+
processor.save_pretrained(path)
|
| 118 |
+
|
| 119 |
+
torch.save(
|
| 120 |
+
{
|
| 121 |
+
"step": step,
|
| 122 |
+
"epoch": epoch,
|
| 123 |
+
"optimizer": optimizer.state_dict() if optimizer else None,
|
| 124 |
+
"scheduler": scheduler.state_dict() if scheduler else None,
|
| 125 |
+
"cfg": cfg_dict,
|
| 126 |
+
},
|
| 127 |
+
os.path.join(path, "train_state.pt"),
|
| 128 |
+
)
|
| 129 |
+
print(f"✅ BLIP checkpoint saved: {path}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def load_ckpt(model, optimizer, scheduler, path):
|
| 133 |
+
"""
|
| 134 |
+
Load model + optimizer/scheduler from a checkpoint directory.
|
| 135 |
+
"""
|
| 136 |
+
loaded_model = BlipForConditionalGeneration.from_pretrained(path)
|
| 137 |
+
model.load_state_dict(loaded_model.state_dict())
|
| 138 |
+
|
| 139 |
+
state_path = os.path.join(path, "train_state.pt")
|
| 140 |
+
if os.path.exists(state_path):
|
| 141 |
+
state = torch.load(state_path, map_location="cpu")
|
| 142 |
+
if optimizer and state.get("optimizer"):
|
| 143 |
+
optimizer.load_state_dict(state["optimizer"])
|
| 144 |
+
if scheduler and state.get("scheduler"):
|
| 145 |
+
scheduler.load_state_dict(state["scheduler"])
|
| 146 |
+
print(f"✅ Resumed from step {state.get('step', '?')}, epoch {state.get('epoch', '?')}")
|
| 147 |
+
return state.get("step", 0), state.get("epoch", 1)
|
| 148 |
+
|
| 149 |
+
print("✅ Model weights loaded, no training state found.")
|
| 150 |
+
return 0, 1
|
models/custom_vlm.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/custom_vlm.py
|
| 3 |
+
=====================
|
| 4 |
+
Advanced Master-Hack — Visual Prefix-Tuning (Shakespeare + ViT)
|
| 5 |
+
|
| 6 |
+
Architecture: A frozen pre-trained ViT (google/vit-base-patch16-224-in21k)
|
| 7 |
+
is fused with a custom character-level causal Transformer decoder trained on
|
| 8 |
+
Shakespeare text. A trainable MLP projection layer bridges the ViT's
|
| 9 |
+
768-dim output to the decoder's 384-dim embedding space.
|
| 10 |
+
|
| 11 |
+
MODALITY FUSION:
|
| 12 |
+
ViT → Project(768→384) → [visual_prefix | char_embeddings] → CausalSelfAttention
|
| 13 |
+
|
| 14 |
+
TRAINING REGIME:
|
| 15 |
+
- ViT: FROZEN (always)
|
| 16 |
+
- Shakespeare Decoder: UNFROZEN during fine-tuning (adapts to COCO captions)
|
| 17 |
+
- visual_projection: TRAINABLE (learned bridge)
|
| 18 |
+
|
| 19 |
+
Weight Loading Strategy:
|
| 20 |
+
The Shakespeare checkpoint uses a custom per-head architecture with keys like:
|
| 21 |
+
blocks.N.sa_head.heads.M.{key,query,value}.weight
|
| 22 |
+
These are remapped to PyTorch nn.TransformerEncoder's fused format:
|
| 23 |
+
decoder_blocks.layers.N.self_attn.in_proj_weight
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from transformers import ViTModel
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 33 |
+
# Character Vocabulary Helper
|
| 34 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 35 |
+
|
| 36 |
+
def build_char_vocab(text_corpus: str):
|
| 37 |
+
"""
|
| 38 |
+
Build a character-level vocabulary from a raw text corpus string.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
chars : sorted list of unique characters
|
| 42 |
+
char_to_idx : dict mapping char → int index
|
| 43 |
+
idx_to_char : dict mapping int index → char
|
| 44 |
+
vocab_size : int
|
| 45 |
+
"""
|
| 46 |
+
chars = sorted(set(text_corpus))
|
| 47 |
+
char_to_idx = {c: i for i, c in enumerate(chars)}
|
| 48 |
+
idx_to_char = {i: c for i, c in enumerate(chars)}
|
| 49 |
+
return chars, char_to_idx, idx_to_char, len(chars)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 53 |
+
# Model Definition
|
| 54 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 55 |
+
|
| 56 |
+
class CustomVLM(nn.Module):
|
| 57 |
+
"""
|
| 58 |
+
Visual Prefix-Tuning VLM.
|
| 59 |
+
|
| 60 |
+
Combines:
|
| 61 |
+
1. Frozen ViT image encoder (768-dim output)
|
| 62 |
+
2. Trainable MLP projection (768 → text_embed_dim)
|
| 63 |
+
3. Character-level causal Transformer decoder
|
| 64 |
+
(initialized from shakespeare_transformer.pt, then fine-tuned)
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
NUM_VISUAL_TOKENS = 197 # ViT: 196 patches + 1 [CLS]
|
| 68 |
+
|
| 69 |
+
def __init__(self, vocab_size, text_embed_dim=384, n_heads=8, n_layers=8,
|
| 70 |
+
block_size=256, dropout=0.1):
|
| 71 |
+
super().__init__()
|
| 72 |
+
|
| 73 |
+
# ── 1. Vision Encoder (Frozen) ──────────────────────────────────────
|
| 74 |
+
self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 75 |
+
for param in self.vit.parameters():
|
| 76 |
+
param.requires_grad = False
|
| 77 |
+
|
| 78 |
+
vit_hidden_size = self.vit.config.hidden_size # 768
|
| 79 |
+
|
| 80 |
+
# ── 2. Trainable Bridge (MLP — like LLaVA) ──────────────────────────
|
| 81 |
+
self.visual_projection = nn.Sequential(
|
| 82 |
+
nn.Linear(vit_hidden_size, vit_hidden_size * 2),
|
| 83 |
+
nn.GELU(),
|
| 84 |
+
nn.Linear(vit_hidden_size * 2, text_embed_dim)
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# ── 3. Character-Level Causal Transformer Decoder ───────────────────
|
| 88 |
+
self.token_embedding_table = nn.Embedding(vocab_size, text_embed_dim)
|
| 89 |
+
# Position table covers visual prefix (197) + max text (block_size)
|
| 90 |
+
self.position_embedding_table = nn.Embedding(
|
| 91 |
+
self.NUM_VISUAL_TOKENS + block_size, text_embed_dim
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
decoder_layer = nn.TransformerEncoderLayer(
|
| 95 |
+
d_model=text_embed_dim,
|
| 96 |
+
nhead=n_heads,
|
| 97 |
+
dim_feedforward=4 * text_embed_dim,
|
| 98 |
+
dropout=dropout,
|
| 99 |
+
batch_first=True,
|
| 100 |
+
)
|
| 101 |
+
self.decoder_blocks = nn.TransformerEncoder(decoder_layer, num_layers=n_layers)
|
| 102 |
+
|
| 103 |
+
self.ln_f = nn.LayerNorm(text_embed_dim)
|
| 104 |
+
self.lm_head = nn.Linear(text_embed_dim, vocab_size)
|
| 105 |
+
|
| 106 |
+
self.block_size = block_size
|
| 107 |
+
self.text_embed_dim = text_embed_dim
|
| 108 |
+
self.vocab_size = vocab_size
|
| 109 |
+
self.n_heads = n_heads
|
| 110 |
+
self.n_layers = n_layers
|
| 111 |
+
|
| 112 |
+
# ───────────────────────────────��─────────────────────────────────────────
|
| 113 |
+
# Weight Loading — with architecture remapping
|
| 114 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 115 |
+
|
| 116 |
+
def load_shakespeare_weights(self, path: str, device: str = "cpu") -> dict:
|
| 117 |
+
"""
|
| 118 |
+
Load pre-trained Shakespeare Transformer weights with full key remapping.
|
| 119 |
+
|
| 120 |
+
The Shakespeare checkpoint uses a custom per-head architecture:
|
| 121 |
+
blocks.N.sa_head.heads.M.{key,query,value}.weight (head_dim, embed_dim)
|
| 122 |
+
blocks.N.sa_head.proj.{weight,bias}
|
| 123 |
+
blocks.N.ffwd.net.{0,2}.{weight,bias}
|
| 124 |
+
blocks.N.ln{1,2}.{weight,bias}
|
| 125 |
+
|
| 126 |
+
These are remapped into PyTorch nn.TransformerEncoder's fused format:
|
| 127 |
+
decoder_blocks.layers.N.self_attn.in_proj_weight (3*embed_dim, embed_dim)
|
| 128 |
+
decoder_blocks.layers.N.self_attn.out_proj.{weight,bias}
|
| 129 |
+
decoder_blocks.layers.N.linear1.{weight,bias}
|
| 130 |
+
decoder_blocks.layers.N.linear2.{weight,bias}
|
| 131 |
+
decoder_blocks.layers.N.norm1.{weight,bias}
|
| 132 |
+
decoder_blocks.layers.N.norm2.{weight,bias}
|
| 133 |
+
"""
|
| 134 |
+
print(f"📖 Loading Shakespeare weights from: {path}")
|
| 135 |
+
|
| 136 |
+
raw = torch.load(path, map_location=device)
|
| 137 |
+
|
| 138 |
+
# Unwrap common checkpoint structures
|
| 139 |
+
if isinstance(raw, dict):
|
| 140 |
+
if "model_state" in raw:
|
| 141 |
+
state_dict = raw["model_state"]
|
| 142 |
+
elif "model" in raw:
|
| 143 |
+
state_dict = raw["model"]
|
| 144 |
+
elif "state_dict" in raw:
|
| 145 |
+
state_dict = raw["state_dict"]
|
| 146 |
+
else:
|
| 147 |
+
state_dict = raw
|
| 148 |
+
else:
|
| 149 |
+
raise TypeError(f"Unexpected checkpoint type: {type(raw)}")
|
| 150 |
+
|
| 151 |
+
# ── Discover Shakespeare architecture ────────────────────────────────
|
| 152 |
+
shk_blocks = set()
|
| 153 |
+
shk_heads = set()
|
| 154 |
+
for key in state_dict:
|
| 155 |
+
if key.startswith("blocks."):
|
| 156 |
+
parts = key.split(".")
|
| 157 |
+
shk_blocks.add(int(parts[1]))
|
| 158 |
+
if "heads" in key:
|
| 159 |
+
shk_heads.add(int(parts[4]))
|
| 160 |
+
|
| 161 |
+
n_shk_blocks = len(shk_blocks)
|
| 162 |
+
n_shk_heads = len(shk_heads) if shk_heads else self.n_heads
|
| 163 |
+
head_dim = self.text_embed_dim // self.n_heads
|
| 164 |
+
|
| 165 |
+
print(f" 📊 Shakespeare arch: {n_shk_blocks} blocks, {n_shk_heads} heads, "
|
| 166 |
+
f"head_dim={head_dim}")
|
| 167 |
+
print(f" 📊 Model arch: {self.n_layers} layers, {self.n_heads} heads")
|
| 168 |
+
|
| 169 |
+
# How many layers to load (min of checkpoint and model)
|
| 170 |
+
n_load = min(n_shk_blocks, self.n_layers)
|
| 171 |
+
n_heads_load = min(n_shk_heads, self.n_heads)
|
| 172 |
+
|
| 173 |
+
remapped = {}
|
| 174 |
+
|
| 175 |
+
# ── Remap decoder blocks ─────────────────────────────────────────────
|
| 176 |
+
for layer_idx in range(n_load):
|
| 177 |
+
prefix_src = f"blocks.{layer_idx}"
|
| 178 |
+
prefix_dst = f"decoder_blocks.layers.{layer_idx}"
|
| 179 |
+
|
| 180 |
+
# 1. Self-Attention: Fuse per-head Q, K, V into in_proj_weight
|
| 181 |
+
# Shakespeare: heads.M.query.weight (head_dim, embed_dim)
|
| 182 |
+
# Target: self_attn.in_proj_weight (3*embed_dim, embed_dim)
|
| 183 |
+
q_parts, k_parts, v_parts = [], [], []
|
| 184 |
+
for h in range(n_heads_load):
|
| 185 |
+
qk = f"{prefix_src}.sa_head.heads.{h}.query.weight"
|
| 186 |
+
kk = f"{prefix_src}.sa_head.heads.{h}.key.weight"
|
| 187 |
+
vk = f"{prefix_src}.sa_head.heads.{h}.value.weight"
|
| 188 |
+
if qk in state_dict and kk in state_dict and vk in state_dict:
|
| 189 |
+
q_parts.append(state_dict[qk])
|
| 190 |
+
k_parts.append(state_dict[kk])
|
| 191 |
+
v_parts.append(state_dict[vk])
|
| 192 |
+
|
| 193 |
+
if q_parts:
|
| 194 |
+
# Concatenate heads: each (head_dim, embed_dim) → (embed_dim, embed_dim)
|
| 195 |
+
Q_full = torch.cat(q_parts, dim=0) # (n_heads*head_dim, embed_dim)
|
| 196 |
+
K_full = torch.cat(k_parts, dim=0)
|
| 197 |
+
V_full = torch.cat(v_parts, dim=0)
|
| 198 |
+
# Fuse into in_proj_weight: [Q; K; V] → (3*embed_dim, embed_dim)
|
| 199 |
+
in_proj_weight = torch.cat([Q_full, K_full, V_full], dim=0)
|
| 200 |
+
remapped[f"{prefix_dst}.self_attn.in_proj_weight"] = in_proj_weight
|
| 201 |
+
|
| 202 |
+
# Create zero bias (Shakespeare has no Q/K/V bias)
|
| 203 |
+
remapped[f"{prefix_dst}.self_attn.in_proj_bias"] = torch.zeros(
|
| 204 |
+
3 * self.text_embed_dim
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# 2. Output projection
|
| 208 |
+
proj_w = f"{prefix_src}.sa_head.proj.weight"
|
| 209 |
+
proj_b = f"{prefix_src}.sa_head.proj.bias"
|
| 210 |
+
if proj_w in state_dict:
|
| 211 |
+
remapped[f"{prefix_dst}.self_attn.out_proj.weight"] = state_dict[proj_w]
|
| 212 |
+
if proj_b in state_dict:
|
| 213 |
+
remapped[f"{prefix_dst}.self_attn.out_proj.bias"] = state_dict[proj_b]
|
| 214 |
+
|
| 215 |
+
# 3. Feed-Forward Network
|
| 216 |
+
# Shakespeare: ffwd.net.0 → linear1, ffwd.net.2 → linear2
|
| 217 |
+
for shk_idx, tgt_name in [("0", "linear1"), ("2", "linear2")]:
|
| 218 |
+
wk = f"{prefix_src}.ffwd.net.{shk_idx}.weight"
|
| 219 |
+
bk = f"{prefix_src}.ffwd.net.{shk_idx}.bias"
|
| 220 |
+
if wk in state_dict:
|
| 221 |
+
remapped[f"{prefix_dst}.{tgt_name}.weight"] = state_dict[wk]
|
| 222 |
+
if bk in state_dict:
|
| 223 |
+
remapped[f"{prefix_dst}.{tgt_name}.bias"] = state_dict[bk]
|
| 224 |
+
|
| 225 |
+
# 4. Layer Norms: ln1 → norm1, ln2 → norm2
|
| 226 |
+
for shk_ln, tgt_ln in [("ln1", "norm1"), ("ln2", "norm2")]:
|
| 227 |
+
for suffix in ("weight", "bias"):
|
| 228 |
+
sk = f"{prefix_src}.{shk_ln}.{suffix}"
|
| 229 |
+
if sk in state_dict:
|
| 230 |
+
remapped[f"{prefix_dst}.{tgt_ln}.{suffix}"] = state_dict[sk]
|
| 231 |
+
|
| 232 |
+
# ── Non-decoder module weights ───────────────────────────────────────
|
| 233 |
+
# token_embedding_table
|
| 234 |
+
if "token_embedding_table.weight" in state_dict:
|
| 235 |
+
shk_emb = state_dict["token_embedding_table.weight"]
|
| 236 |
+
own_emb = self.token_embedding_table.weight
|
| 237 |
+
if shk_emb.shape == own_emb.shape:
|
| 238 |
+
remapped["token_embedding_table.weight"] = shk_emb
|
| 239 |
+
elif shk_emb.shape[1] == own_emb.shape[1]:
|
| 240 |
+
# Vocab size difference: copy what fits
|
| 241 |
+
n_copy = min(shk_emb.shape[0], own_emb.shape[0])
|
| 242 |
+
new_emb = own_emb.data.clone()
|
| 243 |
+
new_emb[:n_copy] = shk_emb[:n_copy]
|
| 244 |
+
remapped["token_embedding_table.weight"] = new_emb
|
| 245 |
+
|
| 246 |
+
# position_embedding_table: Shakespeare (256, 384) → Model (453, 384)
|
| 247 |
+
if "position_embedding_table.weight" in state_dict:
|
| 248 |
+
shk_pos = state_dict["position_embedding_table.weight"] # (256, 384)
|
| 249 |
+
own_pos = self.position_embedding_table.weight # (197+block_size, 384)
|
| 250 |
+
if shk_pos.shape == own_pos.shape:
|
| 251 |
+
remapped["position_embedding_table.weight"] = shk_pos
|
| 252 |
+
else:
|
| 253 |
+
# Expand: zero-init the full table, then copy Shakespeare positions
|
| 254 |
+
# into the TEXT portion (positions 197..197+256)
|
| 255 |
+
new_pos = torch.zeros_like(own_pos.data)
|
| 256 |
+
# Visual positions (0..196) get small random init
|
| 257 |
+
nn.init.normal_(new_pos[:self.NUM_VISUAL_TOKENS], std=0.02)
|
| 258 |
+
# Text positions: copy Shakespeare's first N positions
|
| 259 |
+
n_text_slots = own_pos.shape[0] - self.NUM_VISUAL_TOKENS
|
| 260 |
+
n_copy = min(shk_pos.shape[0], n_text_slots)
|
| 261 |
+
new_pos[self.NUM_VISUAL_TOKENS:self.NUM_VISUAL_TOKENS + n_copy] = shk_pos[:n_copy]
|
| 262 |
+
remapped["position_embedding_table.weight"] = new_pos
|
| 263 |
+
print(f" 📐 Position embeddings expanded: {shk_pos.shape} → {own_pos.shape}")
|
| 264 |
+
|
| 265 |
+
# ln_f (final layer norm)
|
| 266 |
+
for suffix in ("weight", "bias"):
|
| 267 |
+
k = f"ln_f.{suffix}"
|
| 268 |
+
if k in state_dict:
|
| 269 |
+
own_shape = getattr(self.ln_f, suffix).shape
|
| 270 |
+
if state_dict[k].shape == own_shape:
|
| 271 |
+
remapped[k] = state_dict[k]
|
| 272 |
+
|
| 273 |
+
# lm_head
|
| 274 |
+
if "lm_head.weight" in state_dict:
|
| 275 |
+
shk_lm = state_dict["lm_head.weight"]
|
| 276 |
+
own_lm = self.lm_head.weight
|
| 277 |
+
if shk_lm.shape == own_lm.shape:
|
| 278 |
+
remapped["lm_head.weight"] = shk_lm
|
| 279 |
+
elif shk_lm.shape[1] == own_lm.shape[1]:
|
| 280 |
+
n_copy = min(shk_lm.shape[0], own_lm.shape[0])
|
| 281 |
+
new_lm = own_lm.data.clone()
|
| 282 |
+
new_lm[:n_copy] = shk_lm[:n_copy]
|
| 283 |
+
remapped["lm_head.weight"] = new_lm
|
| 284 |
+
|
| 285 |
+
if "lm_head.bias" in state_dict:
|
| 286 |
+
shk_b = state_dict["lm_head.bias"]
|
| 287 |
+
own_b = self.lm_head.bias
|
| 288 |
+
if own_b is not None and shk_b.shape == own_b.shape:
|
| 289 |
+
remapped["lm_head.bias"] = shk_b
|
| 290 |
+
elif own_b is not None:
|
| 291 |
+
n_copy = min(shk_b.shape[0], own_b.shape[0])
|
| 292 |
+
new_b = own_b.data.clone()
|
| 293 |
+
new_b[:n_copy] = shk_b[:n_copy]
|
| 294 |
+
remapped["lm_head.bias"] = new_b
|
| 295 |
+
|
| 296 |
+
# ── Load remapped weights ─────────────────────────────────────────────
|
| 297 |
+
# Verify shapes before loading
|
| 298 |
+
own_state = self.state_dict()
|
| 299 |
+
valid_remapped = {}
|
| 300 |
+
shape_mismatches = []
|
| 301 |
+
for k, v in remapped.items():
|
| 302 |
+
if k in own_state:
|
| 303 |
+
if own_state[k].shape == v.shape:
|
| 304 |
+
valid_remapped[k] = v
|
| 305 |
+
else:
|
| 306 |
+
shape_mismatches.append(
|
| 307 |
+
f" {k}: ckpt={v.shape} vs model={own_state[k].shape}"
|
| 308 |
+
)
|
| 309 |
+
else:
|
| 310 |
+
shape_mismatches.append(f" {k}: not in model state_dict")
|
| 311 |
+
|
| 312 |
+
result = self.load_state_dict(valid_remapped, strict=False)
|
| 313 |
+
|
| 314 |
+
print(f" ✅ Successfully loaded {len(valid_remapped)} weight tensors (of {len(state_dict)} in checkpoint)")
|
| 315 |
+
|
| 316 |
+
if shape_mismatches:
|
| 317 |
+
print(f" ⚠️ {len(shape_mismatches)} shape mismatches (skipped):")
|
| 318 |
+
for msg in shape_mismatches[:5]:
|
| 319 |
+
print(msg)
|
| 320 |
+
|
| 321 |
+
# Count decoder keys that were successfully loaded
|
| 322 |
+
decoder_loaded = sum(1 for k in valid_remapped if k.startswith("decoder_blocks"))
|
| 323 |
+
total_decoder = sum(1 for k in own_state if k.startswith("decoder_blocks"))
|
| 324 |
+
print(f" 📊 Decoder coverage: {decoder_loaded}/{total_decoder} tensors loaded")
|
| 325 |
+
|
| 326 |
+
return {
|
| 327 |
+
"loaded": list(valid_remapped.keys()),
|
| 328 |
+
"missing": result.missing_keys,
|
| 329 |
+
"unexpected": result.unexpected_keys,
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 333 |
+
# Freezing / Unfreezing / Parameter Counting
|
| 334 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 335 |
+
|
| 336 |
+
def freeze_decoder(self):
|
| 337 |
+
"""Freeze the Shakespeare decoder so only visual_projection trains."""
|
| 338 |
+
for name, param in self.named_parameters():
|
| 339 |
+
if not name.startswith("visual_projection"):
|
| 340 |
+
param.requires_grad = False
|
| 341 |
+
# Ensure ViT is frozen
|
| 342 |
+
for param in self.vit.parameters():
|
| 343 |
+
param.requires_grad = False
|
| 344 |
+
|
| 345 |
+
def unfreeze_decoder(self):
|
| 346 |
+
"""
|
| 347 |
+
Unfreeze the decoder for fine-tuning while keeping ViT frozen.
|
| 348 |
+
|
| 349 |
+
This allows the decoder to adapt from Shakespeare text to COCO captions.
|
| 350 |
+
The visual_projection is also trainable.
|
| 351 |
+
"""
|
| 352 |
+
# First, freeze everything
|
| 353 |
+
for param in self.parameters():
|
| 354 |
+
param.requires_grad = False
|
| 355 |
+
|
| 356 |
+
# Unfreeze visual_projection (always trainable)
|
| 357 |
+
for param in self.visual_projection.parameters():
|
| 358 |
+
param.requires_grad = True
|
| 359 |
+
|
| 360 |
+
# Unfreeze ALL decoder components
|
| 361 |
+
for param in self.token_embedding_table.parameters():
|
| 362 |
+
param.requires_grad = True
|
| 363 |
+
for param in self.position_embedding_table.parameters():
|
| 364 |
+
param.requires_grad = True
|
| 365 |
+
for param in self.decoder_blocks.parameters():
|
| 366 |
+
param.requires_grad = True
|
| 367 |
+
for param in self.ln_f.parameters():
|
| 368 |
+
param.requires_grad = True
|
| 369 |
+
for param in self.lm_head.parameters():
|
| 370 |
+
param.requires_grad = True
|
| 371 |
+
|
| 372 |
+
# ViT stays FROZEN
|
| 373 |
+
for param in self.vit.parameters():
|
| 374 |
+
param.requires_grad = False
|
| 375 |
+
|
| 376 |
+
def get_param_groups(self, projection_lr=1e-4, decoder_lr=5e-5):
|
| 377 |
+
"""
|
| 378 |
+
Return optimizer param groups with discriminative learning rates.
|
| 379 |
+
|
| 380 |
+
- visual_projection: higher LR (learning from scratch)
|
| 381 |
+
- decoder: lower LR (gentle adaptation from Shakespeare)
|
| 382 |
+
"""
|
| 383 |
+
projection_params = []
|
| 384 |
+
decoder_params = []
|
| 385 |
+
|
| 386 |
+
for name, param in self.named_parameters():
|
| 387 |
+
if not param.requires_grad:
|
| 388 |
+
continue
|
| 389 |
+
if name.startswith("visual_projection"):
|
| 390 |
+
projection_params.append(param)
|
| 391 |
+
else:
|
| 392 |
+
decoder_params.append(param)
|
| 393 |
+
|
| 394 |
+
return [
|
| 395 |
+
{"params": projection_params, "lr": projection_lr},
|
| 396 |
+
{"params": decoder_params, "lr": decoder_lr},
|
| 397 |
+
]
|
| 398 |
+
|
| 399 |
+
def trainable_params(self):
|
| 400 |
+
"""Return count of trainable parameters."""
|
| 401 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 402 |
+
|
| 403 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 404 |
+
# Forward Pass
|
| 405 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 406 |
+
|
| 407 |
+
def forward(self, pixel_values, text_input_ids, text_targets=None):
|
| 408 |
+
B, T = text_input_ids.shape
|
| 409 |
+
|
| 410 |
+
# ── Image Encoding (frozen ViT) ──────────────────────────────────────
|
| 411 |
+
with torch.no_grad():
|
| 412 |
+
vit_outputs = self.vit(pixel_values=pixel_values)
|
| 413 |
+
image_embeds = vit_outputs.last_hidden_state # (B, 197, 768)
|
| 414 |
+
|
| 415 |
+
# ── Project to text embedding space ──────────────────────────────────
|
| 416 |
+
visual_prefix = self.visual_projection(image_embeds) # (B, 197, 384)
|
| 417 |
+
num_visual = visual_prefix.shape[1] # 197
|
| 418 |
+
|
| 419 |
+
# ── Text Embeddings ───────────────────────────────────────────────────
|
| 420 |
+
T_clipped = min(T, self.block_size)
|
| 421 |
+
text_in = text_input_ids[:, :T_clipped]
|
| 422 |
+
tok_emb = self.token_embedding_table(text_in) # (B, T, 384)
|
| 423 |
+
|
| 424 |
+
# ── Positional Embeddings (covers full combined sequence) ─────────────
|
| 425 |
+
# Positions 0..196 → visual prefix, 197..197+T → text tokens
|
| 426 |
+
total_len = num_visual + T_clipped
|
| 427 |
+
pos_ids = torch.arange(total_len, device=text_in.device)
|
| 428 |
+
pos_emb = self.position_embedding_table(pos_ids) # (num_visual+T, 384)
|
| 429 |
+
|
| 430 |
+
vis_pos = pos_emb[:num_visual] # (197, 384)
|
| 431 |
+
txt_pos = pos_emb[num_visual:] # (T, 384)
|
| 432 |
+
|
| 433 |
+
visual_emb = visual_prefix + vis_pos # (B, 197, 384)
|
| 434 |
+
text_emb = tok_emb + txt_pos # (B, T, 384)
|
| 435 |
+
|
| 436 |
+
# ── Fusion: [visual_prefix | text_emb] ───────────────────────────────
|
| 437 |
+
combined = torch.cat([visual_emb, text_emb], dim=1) # (B, 197+T, 384)
|
| 438 |
+
tot = combined.shape[1]
|
| 439 |
+
|
| 440 |
+
# ── Causal Attention Mask ─────────────────────────────────────────────
|
| 441 |
+
# Visual tokens attend to each other freely.
|
| 442 |
+
# Text tokens attend to all visual tokens + causally to previous text.
|
| 443 |
+
mask = torch.full((tot, tot), float("-inf"), device=text_in.device)
|
| 444 |
+
mask[:num_visual, :num_visual] = 0.0 # visual→visual: free
|
| 445 |
+
mask[num_visual:, :num_visual] = 0.0 # text→visual: free
|
| 446 |
+
causal = torch.triu(
|
| 447 |
+
torch.full((T_clipped, T_clipped), float("-inf"), device=text_in.device),
|
| 448 |
+
diagonal=1,
|
| 449 |
+
)
|
| 450 |
+
mask[num_visual:, num_visual:] = causal # text→text: causal
|
| 451 |
+
|
| 452 |
+
# ── Decoder ───────────────────────────────────────────────────────────
|
| 453 |
+
x = self.decoder_blocks(combined, mask=mask, is_causal=False)
|
| 454 |
+
text_out = x[:, num_visual:, :]
|
| 455 |
+
text_out = self.ln_f(text_out)
|
| 456 |
+
logits = self.lm_head(text_out) # (B, T, vocab)
|
| 457 |
+
|
| 458 |
+
# ── Loss (ignore padding index 0) ─────────────────────────────────────
|
| 459 |
+
loss = None
|
| 460 |
+
if text_targets is not None:
|
| 461 |
+
tgt = text_targets[:, :T_clipped]
|
| 462 |
+
loss = F.cross_entropy(
|
| 463 |
+
logits.reshape(B * T_clipped, -1),
|
| 464 |
+
tgt.reshape(B * T_clipped),
|
| 465 |
+
ignore_index=0,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
return logits, loss
|
| 469 |
+
|
| 470 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 471 |
+
# Generation
|
| 472 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 473 |
+
|
| 474 |
+
@torch.no_grad()
|
| 475 |
+
def generate(self, pixel_values, char_to_idx, idx_to_char,
|
| 476 |
+
max_new_tokens=100, temperature=0.8):
|
| 477 |
+
"""
|
| 478 |
+
Autoregressive character-level caption generation (temperature sampling).
|
| 479 |
+
|
| 480 |
+
Args:
|
| 481 |
+
pixel_values : (1, 3, H, W) pre-processed image tensor
|
| 482 |
+
char_to_idx : character → index mapping
|
| 483 |
+
idx_to_char : index → character mapping
|
| 484 |
+
max_new_tokens : how many characters to generate
|
| 485 |
+
temperature : sampling temperature (0.8 = slightly sharper than uniform)
|
| 486 |
+
|
| 487 |
+
Returns:
|
| 488 |
+
generated_text : str
|
| 489 |
+
"""
|
| 490 |
+
self.eval()
|
| 491 |
+
device = pixel_values.device
|
| 492 |
+
|
| 493 |
+
bos_idx = char_to_idx.get("\n", 0)
|
| 494 |
+
idx_seq = torch.tensor([[bos_idx]], dtype=torch.long, device=device)
|
| 495 |
+
|
| 496 |
+
for _ in range(max_new_tokens):
|
| 497 |
+
# Clip text to block_size — the forward method handles the visual
|
| 498 |
+
# prefix separately, so we only need to limit the text portion.
|
| 499 |
+
idx_cond = idx_seq[:, -self.block_size:]
|
| 500 |
+
logits, _ = self(pixel_values, idx_cond)
|
| 501 |
+
# Take the last time step
|
| 502 |
+
logits_last = logits[:, -1, :] / max(temperature, 1e-5)
|
| 503 |
+
probs = F.softmax(logits_last, dim=-1)
|
| 504 |
+
next_idx = torch.multinomial(probs, num_samples=1)
|
| 505 |
+
idx_seq = torch.cat([idx_seq, next_idx], dim=1)
|
| 506 |
+
|
| 507 |
+
# Decode, skip the leading BOS
|
| 508 |
+
generated = "".join(
|
| 509 |
+
idx_to_char.get(i.item(), "?") for i in idx_seq[0, 1:]
|
| 510 |
+
)
|
| 511 |
+
return generated
|
| 512 |
+
|
| 513 |
+
@torch.no_grad()
|
| 514 |
+
def generate_beam(self, pixel_values, char_to_idx, idx_to_char,
|
| 515 |
+
max_new_tokens=100, num_beams=4, length_penalty=1.0):
|
| 516 |
+
"""
|
| 517 |
+
Beam-search character-level caption generation.
|
| 518 |
+
|
| 519 |
+
At each step we keep the top `num_beams` partial sequences ranked by
|
| 520 |
+
cumulative log-probability (with optional length penalty).
|
| 521 |
+
|
| 522 |
+
Args:
|
| 523 |
+
pixel_values : (1, 3, H, W) image tensor
|
| 524 |
+
char_to_idx : char → idx mapping
|
| 525 |
+
idx_to_char : idx → char mapping
|
| 526 |
+
max_new_tokens : max characters to generate
|
| 527 |
+
num_beams : beam width (1 = greedy)
|
| 528 |
+
length_penalty : >1 favors longer sequences; <1 favors shorter
|
| 529 |
+
|
| 530 |
+
Returns:
|
| 531 |
+
generated_text : str (best beam)
|
| 532 |
+
"""
|
| 533 |
+
self.eval()
|
| 534 |
+
device = pixel_values.device
|
| 535 |
+
|
| 536 |
+
bos_idx = char_to_idx.get("\n", 0)
|
| 537 |
+
# Each beam: (score, token_sequence_tensor)
|
| 538 |
+
beams = [(0.0, torch.tensor([[bos_idx]], dtype=torch.long, device=device))]
|
| 539 |
+
|
| 540 |
+
for _ in range(max_new_tokens):
|
| 541 |
+
candidates = []
|
| 542 |
+
for score, seq in beams:
|
| 543 |
+
idx_cond = seq[:, -self.block_size:]
|
| 544 |
+
logits, _ = self(pixel_values, idx_cond)
|
| 545 |
+
log_probs = F.log_softmax(logits[:, -1, :], dim=-1) # (1, vocab)
|
| 546 |
+
topk_probs, topk_ids = log_probs.topk(num_beams, dim=-1)
|
| 547 |
+
|
| 548 |
+
for k in range(num_beams):
|
| 549 |
+
new_score = score + topk_probs[0, k].item()
|
| 550 |
+
new_seq = torch.cat(
|
| 551 |
+
[seq, topk_ids[:, k:k+1]], dim=1
|
| 552 |
+
)
|
| 553 |
+
candidates.append((new_score, new_seq))
|
| 554 |
+
|
| 555 |
+
# Apply length penalty and keep top beams
|
| 556 |
+
candidates.sort(
|
| 557 |
+
key=lambda x: x[0] / (x[1].shape[1] ** length_penalty),
|
| 558 |
+
reverse=True,
|
| 559 |
+
)
|
| 560 |
+
beams = candidates[:num_beams]
|
| 561 |
+
|
| 562 |
+
best_seq = beams[0][1]
|
| 563 |
+
return "".join(idx_to_char.get(i.item(), "?") for i in best_seq[0, 1:])
|
models/git_tuner.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/git_tuner.py
|
| 3 |
+
===================
|
| 4 |
+
Baseline 2 — Zero Cross-Attention / Self-Attention Prefix (GIT)
|
| 5 |
+
|
| 6 |
+
Architecture: GIT (Generative Image-to-Text) abandons cross-attention entirely.
|
| 7 |
+
It concatenates image patch embeddings directly in front of the text tokens and
|
| 8 |
+
runs a single causal self-attention Transformer over the combined sequence.
|
| 9 |
+
|
| 10 |
+
There is NO cross-attention block. The model learns to fuse modalities purely
|
| 11 |
+
through self-attention across a unified image+text token sequence. This makes
|
| 12 |
+
the ablation masks work differently — we control which image tokens are
|
| 13 |
+
prepended to the sequence rather than using encoder_attention_mask.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import torch
|
| 18 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_git_model(cfg, device):
|
| 22 |
+
"""
|
| 23 |
+
Load microsoft/git-base-coco with gradient checkpointing.
|
| 24 |
+
GIT uses AutoModelForCausalLM interface.
|
| 25 |
+
"""
|
| 26 |
+
model_id = cfg.git_model_id
|
| 27 |
+
|
| 28 |
+
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
|
| 29 |
+
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
model.gradient_checkpointing_enable()
|
| 33 |
+
print("✅ Gradient checkpointing enabled (GIT)")
|
| 34 |
+
except Exception as e:
|
| 35 |
+
print(f"⚠️ Gradient checkpointing failed: {e}")
|
| 36 |
+
|
| 37 |
+
model.config.use_cache = False
|
| 38 |
+
model.to(device)
|
| 39 |
+
|
| 40 |
+
n_params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 41 |
+
print(f"✅ GIT loaded on {device}: {model_id} ({n_params:.1f}M params)")
|
| 42 |
+
return model, processor
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def generate_caption(model, processor, image_pil, device,
|
| 46 |
+
max_new_tokens=32, num_beams=4):
|
| 47 |
+
"""
|
| 48 |
+
Generate a caption for a single PIL image using GIT.
|
| 49 |
+
|
| 50 |
+
Note: GIT has no encoder_attention_mask concept (no cross-attention).
|
| 51 |
+
Ablation for GIT is handled upstream by modifying the pixel_values
|
| 52 |
+
(e.g., masking image regions) before passing to the model, OR by
|
| 53 |
+
returning a note that GIT is not compatible with encoder-mask ablations.
|
| 54 |
+
"""
|
| 55 |
+
model.eval()
|
| 56 |
+
inputs = processor(images=image_pil, return_tensors="pt").to(device)
|
| 57 |
+
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
output_ids = model.generate(
|
| 60 |
+
**inputs,
|
| 61 |
+
max_new_tokens=max_new_tokens,
|
| 62 |
+
num_beams=num_beams,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
|
| 66 |
+
return caption
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def save_ckpt(model, processor, optimizer, scheduler,
|
| 70 |
+
step, epoch, cfg_dict, path):
|
| 71 |
+
os.makedirs(path, exist_ok=True)
|
| 72 |
+
model.save_pretrained(path)
|
| 73 |
+
processor.save_pretrained(path)
|
| 74 |
+
|
| 75 |
+
torch.save(
|
| 76 |
+
{
|
| 77 |
+
"step": step,
|
| 78 |
+
"epoch": epoch,
|
| 79 |
+
"optimizer": optimizer.state_dict() if optimizer else None,
|
| 80 |
+
"scheduler": scheduler.state_dict() if scheduler else None,
|
| 81 |
+
"cfg": cfg_dict,
|
| 82 |
+
},
|
| 83 |
+
os.path.join(path, "train_state.pt"),
|
| 84 |
+
)
|
| 85 |
+
print(f"✅ GIT checkpoint saved: {path}")
|
models/vit_gpt2_tuner.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/vit_gpt2_tuner.py
|
| 3 |
+
========================
|
| 4 |
+
Baseline 1 — Standard Cross-Attention (ViT-GPT2)
|
| 5 |
+
|
| 6 |
+
Architecture: Every generated text token in the GPT-2 decoder attends to ALL
|
| 7 |
+
197 ViT patch embeddings via explicit cross-attention blocks injected between
|
| 8 |
+
each GPT-2 self-attention layer.
|
| 9 |
+
|
| 10 |
+
This is the "brute-force" cross-attention baseline: no restrictions, no pooling.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import torch
|
| 15 |
+
from transformers import (
|
| 16 |
+
VisionEncoderDecoderModel,
|
| 17 |
+
ViTImageProcessor,
|
| 18 |
+
AutoTokenizer,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_vit_gpt2_model(cfg, device):
|
| 23 |
+
"""
|
| 24 |
+
Load the VisionEncoderDecoderModel (ViT-GPT2) with:
|
| 25 |
+
- Gradient checkpointing enabled
|
| 26 |
+
- use_cache=False (required with grad checkpointing)
|
| 27 |
+
- Proper pad/bos/eos tokens set for GPT-2
|
| 28 |
+
"""
|
| 29 |
+
model_id = cfg.vit_gpt2_model_id
|
| 30 |
+
|
| 31 |
+
processor = ViTImageProcessor.from_pretrained(model_id, use_fast=True)
|
| 32 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 33 |
+
|
| 34 |
+
# GPT-2 has no pad token by default — use eos as pad
|
| 35 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 36 |
+
|
| 37 |
+
model = VisionEncoderDecoderModel.from_pretrained(model_id)
|
| 38 |
+
model.config.decoder_start_token_id = tokenizer.bos_token_id
|
| 39 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 40 |
+
model.config.eos_token_id = tokenizer.eos_token_id
|
| 41 |
+
|
| 42 |
+
# Memory optimizations
|
| 43 |
+
try:
|
| 44 |
+
model.gradient_checkpointing_enable()
|
| 45 |
+
print("✅ Gradient checkpointing enabled (ViT-GPT2)")
|
| 46 |
+
except Exception as e:
|
| 47 |
+
print(f"⚠️ Gradient checkpointing failed: {e}")
|
| 48 |
+
|
| 49 |
+
model.config.use_cache = False
|
| 50 |
+
|
| 51 |
+
# Resize images to cfg.image_size
|
| 52 |
+
try:
|
| 53 |
+
processor.size = {"height": cfg.image_size, "width": cfg.image_size}
|
| 54 |
+
print(f"✅ Image size set to {cfg.image_size}px")
|
| 55 |
+
except Exception as e:
|
| 56 |
+
print(f"⚠️ Could not set image size: {e}")
|
| 57 |
+
|
| 58 |
+
model.to(device)
|
| 59 |
+
n_params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 60 |
+
print(f"✅ ViT-GPT2 loaded on {device}: {model_id} ({n_params:.1f}M params)")
|
| 61 |
+
|
| 62 |
+
return model, processor, tokenizer
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def generate_caption(model, processor, tokenizer, image_pil, device,
|
| 66 |
+
max_new_tokens=32, num_beams=4,
|
| 67 |
+
encoder_attention_mask=None):
|
| 68 |
+
"""
|
| 69 |
+
Generate a caption for a single PIL image.
|
| 70 |
+
|
| 71 |
+
encoder_attention_mask: (1, num_patches) allows ablation-mode masking.
|
| 72 |
+
If None, defaults to full attention (all 1s).
|
| 73 |
+
"""
|
| 74 |
+
model.eval()
|
| 75 |
+
inputs = processor(images=image_pil, return_tensors="pt").to(device)
|
| 76 |
+
pixel_values = inputs["pixel_values"]
|
| 77 |
+
|
| 78 |
+
gen_kwargs = dict(
|
| 79 |
+
pixel_values=pixel_values,
|
| 80 |
+
max_new_tokens=max_new_tokens,
|
| 81 |
+
num_beams=num_beams,
|
| 82 |
+
)
|
| 83 |
+
if encoder_attention_mask is not None:
|
| 84 |
+
gen_kwargs["attention_mask"] = encoder_attention_mask.to(device)
|
| 85 |
+
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
output_ids = model.generate(**gen_kwargs)
|
| 88 |
+
|
| 89 |
+
caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 90 |
+
return caption
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def save_ckpt(model, processor, tokenizer, optimizer, scheduler,
|
| 94 |
+
step, epoch, cfg_dict, path):
|
| 95 |
+
os.makedirs(path, exist_ok=True)
|
| 96 |
+
model.save_pretrained(path)
|
| 97 |
+
processor.save_pretrained(path)
|
| 98 |
+
tokenizer.save_pretrained(path)
|
| 99 |
+
|
| 100 |
+
torch.save(
|
| 101 |
+
{
|
| 102 |
+
"step": step,
|
| 103 |
+
"epoch": epoch,
|
| 104 |
+
"optimizer": optimizer.state_dict() if optimizer else None,
|
| 105 |
+
"scheduler": scheduler.state_dict() if scheduler else None,
|
| 106 |
+
"cfg": cfg_dict,
|
| 107 |
+
},
|
| 108 |
+
os.path.join(path, "train_state.pt"),
|
| 109 |
+
)
|
| 110 |
+
print(f"✅ ViT-GPT2 checkpoint saved: {path}")
|
project_02_DS
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit a5ea2c20321ecd6767352a2393f0bc58a8a9f059
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
torchaudio
|
| 4 |
+
transformers>=4.37.0
|
| 5 |
+
datasets
|
| 6 |
+
aiohttp
|
| 7 |
+
streamlit
|
| 8 |
+
numpy
|
| 9 |
+
Pillow
|
| 10 |
+
tqdm
|
| 11 |
+
accelerate
|
| 12 |
+
sentencepiece
|
| 13 |
+
pycocoevalcap
|
shakespeare_transformer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:652c085bf4c7275182fe726a38d3034aeaf1d67d8dc93f8c014976b2408f7ce5
|
| 3 |
+
size 74253331
|
simplified_overview_vlm_image_captioning_project.md
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How I Built a System That Teaches Computers to Describe Photographs
|
| 2 |
+
|
| 3 |
+
**A non-technical overview of the VLM Caption Lab project**
|
| 4 |
+
*Author: Manoj Kumar | 4 March 2026*
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## What Is This Project About?
|
| 9 |
+
|
| 10 |
+
Imagine showing a photograph to a friend and asking them to describe it in one sentence. They might say, *"A man in a suit standing in front of a tree,"* or *"A tennis match in a large arena with a crowd watching."* For us, this is effortless — our brains process the entire image, identify the objects, understand the scene, and produce a fluent sentence in under a second.
|
| 11 |
+
|
| 12 |
+
For a computer, this is remarkably difficult. The technical name for this task is **"image captioning,"** and it lives at the crossroads of two hard problems: understanding what is in an image (computer vision) and writing grammatically correct, meaningful sentences (natural language generation).
|
| 13 |
+
|
| 14 |
+
This project explores that challenge — **but I did not just build one system. I built and compared four of them,** each with a fundamentally different approach to the core problem of looking at the image while writing about it.
|
| 15 |
+
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
## The Four Models I Built (And Why They Are Different)
|
| 19 |
+
|
| 20 |
+
Think of image captioning like a person looking at a painting while narrating what they see into a microphone. The four models I compared differ in **how the person glances at the painting while they talk.**
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
### 🔵 Model 1: BLIP — The Selective Glancer
|
| 25 |
+
|
| 26 |
+
**How it works :** BLIP is like a narrator who has trained themselves to only glance at the painting when they need to. When they are saying generic words like "a" or "the" or "is," they just focus on their own sentence. When they need to mention something specific — like "bicycle" or "standing" — they look up at the painting to confirm what they see.
|
| 27 |
+
|
| 28 |
+
**Why this is smart:** Most words in a sentence are structural, not visual. There is no need to look at the image to say "the" or "in front of." BLIP learns when to look and when not to, which prevents it from getting confused by too much visual information.
|
| 29 |
+
|
| 30 |
+
**Size:** 224 million parameters
|
| 31 |
+
**Best CIDEr score:** **0.62** (with optimized settings)
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
### Model 2: ViT-GPT2 — The Constant Starer
|
| 36 |
+
|
| 37 |
+
**How it works in plain English:** ViT-GPT2 takes the opposite approach — for every single word, it stares at the entire painting. Writing "a"? Look at the whole image. Writing "dog"? Look at the whole image. Writing "the"? Still looking at the whole image.
|
| 38 |
+
|
| 39 |
+
**Why this still works:** Even though it is wasteful, staring at everything guarantees the model never misses any visual detail. The downside is that this constant stream of visual information can sometimes confuse the language part of the model.
|
| 40 |
+
|
| 41 |
+
**Size:** 239 million parameters
|
| 42 |
+
**Typical CIDEr score:** ~0.55
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
### Model 3: GIT — The Memorizer
|
| 47 |
+
|
| 48 |
+
**How it works in plain English:** GIT does something clever — instead of switching between looking at the painting and writing words, it first memorizes the entire painting and then writes the caption purely from memory.
|
| 49 |
+
|
| 50 |
+
In technical terms, GIT converts the image into a set of structured "memory notes" and places them at the beginning of its sentence. Then it processes everything — image memories and text — in one continuous stream. There is no separate "looking at the painting" step.
|
| 51 |
+
|
| 52 |
+
**Why this is elegant:** It is simpler and faster because it does not need the extra machinery for looking back and forth between image and text. The entire intelligence is in one unified processing step.
|
| 53 |
+
|
| 54 |
+
**Size:** 177 million parameters (smallest of the four)
|
| 55 |
+
**Typical CIDEr score:** ~0.54
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
### Model 4: Custom VLM — The Shakespeare Bot Learning Modern English
|
| 60 |
+
|
| 61 |
+
**How it works in plain English:** This is the most experimental model, and the one **I built entirely from scratch.** Imagine a narrator who grew up reading only Shakespeare and has never seen a photograph before. You give them a pair of glasses (a visual encoder — something that can look at images) and a translator (a small bridging network) and ask them to describe modern photographs.
|
| 62 |
+
|
| 63 |
+
The "Shakespeare bot" is a text generator I had previously trained on the complete works of Shakespeare. It knows English grammar and sentence structure — but in Elizabethan English. The challenge was teaching it to (a) understand images through the "glasses" and (b) speak in modern, descriptive English instead of iambic pentameter.
|
| 64 |
+
|
| 65 |
+
**Why I built this:** To understand what minimum set of components you need to make a functioning vision-language model. Instead of downloading a ready-made model with billions of parameters, I wanted to see if I could glue together a vision model and a text model with just a small trainable "bridge" in between.
|
| 66 |
+
|
| 67 |
+
**Size:** 103 million parameters total, but only **16.2 million are trainable** (the rest are frozen)
|
| 68 |
+
**Best CIDEr score:** **0.2863** (still learning, but it works!)
|
| 69 |
+
|
| 70 |
+
---
|
| 71 |
+
|
| 72 |
+
## What Is CIDEr? (The Score We Use to Measure Quality)
|
| 73 |
+
|
| 74 |
+
Throughout this summary, I mention "CIDEr scores." Here is what they mean:
|
| 75 |
+
|
| 76 |
+
**CIDEr** stands for "Consensus-based Image Description Evaluation." In simple terms, it compares the caption our model generates to **five human-written descriptions** of the same image.
|
| 77 |
+
|
| 78 |
+
- It counts how many meaningful words overlap between the model's caption and the human captions
|
| 79 |
+
- It gives more weight to descriptive words (like "bicycle" or "stadium") than common words (like "the" or "is")
|
| 80 |
+
- **A higher score means the computer's description sounds more like what a human would write**
|
| 81 |
+
|
| 82 |
+
| CIDEr Score | What It Means |
|
| 83 |
+
|---|---|
|
| 84 |
+
| 0.00 | Completely wrong — no overlap with human descriptions |
|
| 85 |
+
| 0.20–0.30 | Early stage — some correct words, but the sentence may be awkward |
|
| 86 |
+
| 0.50–0.60 | Good — clearly related to the image, mostly sensible |
|
| 87 |
+
| 0.80–1.00 | Excellent — almost indistinguishable from a human caption |
|
| 88 |
+
|
| 89 |
+
---
|
| 90 |
+
|
| 91 |
+
## The Custom Model Story: A Journey of Debugging and Discovery
|
| 92 |
+
|
| 93 |
+
This is the part of the project I am most proud of, because it taught me the most about how machine learning actually works in practice — not just in theory, but when things go wrong.
|
| 94 |
+
|
| 95 |
+
### Chapter 1: "Why Is It Speaking Gibberish?"
|
| 96 |
+
|
| 97 |
+
My first attempt at the Custom VLM produced output like this:
|
| 98 |
+
|
| 99 |
+
> *"iGiiiiiGiviqiGqiFliqiGidlidiliGilFGilqiiiqiiiiGii"*
|
| 100 |
+
|
| 101 |
+
That is not English. That is not even Shakespeare. It is random noise.
|
| 102 |
+
|
| 103 |
+
**The problem:** The connection between the "glasses" (the image encoder) and the "brain" (the Shakespeare text generator) was too weak. I was using a single mathematical transformation to convert visual information into text information. Think of it like trying to translate a painting into a poem by only measuring the canvas size — you are missing all the important details.
|
| 104 |
+
|
| 105 |
+
**CIDEr score at this stage: 0.0000 — literally zero.**
|
| 106 |
+
|
| 107 |
+
### Chapter 2: "Better Connection, But Still Broken"
|
| 108 |
+
|
| 109 |
+
I upgraded the connection to a more powerful two-layer network. This is like upgrading from a basic dictionary to a bilingual tutor who understands context. The training measurements started improving — the numbers were going down, which normally means the model is learning.
|
| 110 |
+
|
| 111 |
+
But the output was still gibberish.
|
| 112 |
+
|
| 113 |
+
After days of investigation, I found the real problem — and it was a doozy:
|
| 114 |
+
|
| 115 |
+
> **When I loaded the Shakespeare brain into the model, 97% of the brain weights failed to load. Silently. No error message. No warning. The software just said "everything is fine" and moved on.**
|
| 116 |
+
|
| 117 |
+
My model had been running on a **randomly initialized brain** — essentially trying to learn language from scratch while simultaneously trying to learn to describe images. Imagine asking someone with amnesia to write poetry about something they've never seen. That's what my model was trying to do.
|
| 118 |
+
|
| 119 |
+
**Why did this happen?** The two models (Shakespeare and my VLM) stored their internal knowledge in slightly different formats. It is like trying to load a Word document into Excel — both are files, but the internal structure is completely different. The software saw the mismatched formats and just... skipped everything. Without telling me.
|
| 120 |
+
|
| 121 |
+
### Chapter 3: "It Finally Speaks!"
|
| 122 |
+
|
| 123 |
+
The fix required three things:
|
| 124 |
+
1. **Match the formats** — Make the new model structure identical to the Shakespeare model's structure (8 layers, 8 attention heads, matching dimensions)
|
| 125 |
+
2. **Translate the weights** — Write custom code to convert the Shakespeare data from one format to another
|
| 126 |
+
3. **Let the brain learn** — Instead of freezing the Shakespeare knowledge, let the model slowly adapt from old English to modern descriptions
|
| 127 |
+
|
| 128 |
+
**The result was immediate.** From the very first training session after the fix, the improvement was dramatic:
|
| 129 |
+
|
| 130 |
+
> Before fix: *"iGiiiiiGiviqiGqiFliqiGidlidiliGilFGilqiiiqiiiiGii"* (CIDEr: 0.0000)
|
| 131 |
+
> After fix: *"man in the bluess and white play with and a pizza"* (CIDEr: 0.2863)
|
| 132 |
+
|
| 133 |
+
Not perfect. Not even grammatically correct. But it is **clearly English**, it is **clearly attempting to describe an image**, and it went from zero to something meaningful. The word "man" appeared because the image showed a man. The model learned real English words and connected them to visual concepts.
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## What We Tested: The Three Experiments
|
| 138 |
+
|
| 139 |
+
### Experiment 1: "Can We Cover Part of the Image?"
|
| 140 |
+
|
| 141 |
+
I blocked parts of the image from the model and measured whether the captions got worse. The results were genuinely surprising:
|
| 142 |
+
|
| 143 |
+
| What We Did | Effect on Caption Quality |
|
| 144 |
+
|---|---|
|
| 145 |
+
| Showed the **full image** | Baseline quality (CIDEr: 0.5371) |
|
| 146 |
+
| **Hid 50%** of the image randomly | **No change at all** (CIDEr: 0.5371) |
|
| 147 |
+
| Showed **only the center** (removed background) | **No change at all** (CIDEr: 0.5371) |
|
| 148 |
+
| **Compressed everything** into one tiny summary | **Complete failure** (CIDEr: 0.0008 — a 99.8% drop) |
|
| 149 |
+
|
| 150 |
+
**What this teaches us:** Images contain a lot of redundant information. You can throw away half the visual data and still get perfectly good captions. But if you compress everything into a single summary, you lose the information about **where things are** relative to each other — and that spatial information turns out to be essential for describing a scene.
|
| 151 |
+
|
| 152 |
+
### Experiment 2: "What Settings Produce the Best Captions?"
|
| 153 |
+
|
| 154 |
+
When a model generates a caption, it uses a search algorithm that considers multiple possible sentences and picks the best one. I tested **18 different combinations** of settings and found:
|
| 155 |
+
|
| 156 |
+
- **Considering more candidate sentences (10 instead of 3) helped significantly** — about 13% improvement
|
| 157 |
+
- **Slightly encouraging shorter captions helped** — models tend to ramble when given too much freedom
|
| 158 |
+
- **Best combination found: CIDEr score of 0.6199** (up from 0.48 with the worst settings)
|
| 159 |
+
|
| 160 |
+
### Experiment 3: "Does Caption Quality During Training Matter?"
|
| 161 |
+
|
| 162 |
+
I compared different strategies for selecting which human captions to show the model during training:
|
| 163 |
+
|
| 164 |
+
| Strategy | CIDEr Score |
|
| 165 |
+
|---|---|
|
| 166 |
+
| Use any random caption | **0.6359** ← best for this clean dataset |
|
| 167 |
+
| Use only short captions (≤ 9 words) | 0.6016 |
|
| 168 |
+
| Use only medium-length captions (5–25 words) | 0.5877 |
|
| 169 |
+
| Use only long captions (≥ 12 words) | 0.5389 |
|
| 170 |
+
|
| 171 |
+
**Bottom line:** For this particular dataset (which is already well-curated), using raw unfiltered captions works best. But filtering is recommended for noisier datasets.
|
| 172 |
+
|
| 173 |
+
---
|
| 174 |
+
|
| 175 |
+
## The Interactive Demo
|
| 176 |
+
|
| 177 |
+
I built a web application where anyone can try the models themselves:
|
| 178 |
+
|
| 179 |
+
- **Upload any photo** and get a caption from any of the four models
|
| 180 |
+
- **Compare all four models** side by side on the same image — see how each one describes the same picture differently
|
| 181 |
+
- **Switch between pre-trained and fine-tuned** versions of each model
|
| 182 |
+
- **Adjust generation settings** — control how the model searches for the best caption
|
| 183 |
+
- **View experiment results** — browse all the findings from the three experiments
|
| 184 |
+
|
| 185 |
+
Every generated caption goes through a **safety filter** before being shown, because AI models can occasionally produce inappropriate descriptions. The filter uses a toxicity detection model to catch and block offensive content.
|
| 186 |
+
|
| 187 |
+
---
|
| 188 |
+
|
| 189 |
+
## Summary of Results
|
| 190 |
+
|
| 191 |
+
| Model | Approach | CIDEr Score | Key Strength |
|
| 192 |
+
|---|---|---|---|
|
| 193 |
+
| **BLIP** | Selective looking | **0.62** (best settings) | Best quality — knows when to look vs. when to focus on grammar |
|
| 194 |
+
| **ViT-GPT2** | Constant looking | ~0.55 | Strong baseline — full visual access at all times |
|
| 195 |
+
| **GIT** | Memory-based | ~0.54 | Elegant and efficient — no cross-attention needed at all |
|
| 196 |
+
| **Custom VLM** | Built from scratch | **0.29** | Proof of concept — works despite tiny vocabulary and Shakespeare origins |
|
| 197 |
+
|
| 198 |
+
---
|
| 199 |
+
|
| 200 |
+
## What I Actually Learned
|
| 201 |
+
|
| 202 |
+
1. **There is no single best way to connect vision and language.** BLIP's selective attention works best overall, but GIT's simpler approach is surprisingly competitive — proving that you do not always need complex mechanisms to solve complex problems.
|
| 203 |
+
|
| 204 |
+
2. **Silent failures are the most dangerous bugs in machine learning.** The most time-consuming problem in this project was a weight-loading failure that produced zero error messages. The model ran, the loss decreased, everything looked normal — but 97% of the model was running on random noise. I now always verify that weights loaded correctly.
|
| 205 |
+
|
| 206 |
+
3. **The number your model optimizes during training is not necessarily the number that tells you if it is doing a good job.** Training loss went down steadily, but the captions were still gibberish. Only when I started measuring CIDEr (actual caption quality) did I understand what was really happening.
|
| 207 |
+
|
| 208 |
+
4. **Small models can learn big tasks with the right approach.** The Custom VLM has only 16.2 million trainable parameters — roughly 1/15th the size of BLIP — yet it learned to produce recognizable English descriptions of images by building on existing Shakespeare knowledge.
|
| 209 |
+
|
| 210 |
+
5. **Images are surprisingly redundant.** You can literally hide half the image and the model generates identical captions. But structure matters — where objects are relative to each other is more important than being able to see every pixel.
|
| 211 |
+
|
| 212 |
+
---
|
| 213 |
+
|
| 214 |
+
## What Could Be Improved Next
|
| 215 |
+
|
| 216 |
+
If I continue this project, the highest-impact improvements would be:
|
| 217 |
+
|
| 218 |
+
- **Better vocabulary:** The Custom VLM currently spells everything letter-by-letter (65 characters). Switching to a word-piece vocabulary (thousands of tokens) would dramatically reduce the difficulty.
|
| 219 |
+
- **Stronger language foundation:** Replacing the Shakespeare decoder with a modern language model like GPT-2 would give the model native modern English instead of having to translate from Elizabethan.
|
| 220 |
+
- **More training data:** We currently use only 18% of the available dataset images.
|
| 221 |
+
|
| 222 |
+
---
|
| 223 |
+
|
| 224 |
+
*Project by Manoj Kumar, March 2026*
|
train.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train.py
|
| 3 |
+
========
|
| 4 |
+
Unified training entrypoint for all VLM architectures:
|
| 5 |
+
--model blip → Fine-tune BLIP (Multimodal Mixture Attention)
|
| 6 |
+
--model vit_gpt2 → Fine-tune ViT-GPT2 (Standard Cross-Attention)
|
| 7 |
+
--model git → Fine-tune GIT (Zero Cross-Attention / Self-Attention Prefix)
|
| 8 |
+
--model custom → Train visual_projection only (Visual Prefix-Tuning)
|
| 9 |
+
|
| 10 |
+
Checkpoint Strategy:
|
| 11 |
+
All outputs are saved under outputs/{model_name}/:
|
| 12 |
+
- latest/ — overwritten every epoch (always the most recent state)
|
| 13 |
+
- best/ — overwritten only when validation loss improves
|
| 14 |
+
|
| 15 |
+
Optimized for Apple Silicon MPS backend with:
|
| 16 |
+
- Gradient accumulation
|
| 17 |
+
- Gradient checkpointing
|
| 18 |
+
- Cosine LR scheduler with linear warmup
|
| 19 |
+
- MPS-safe DataLoader settings (num_workers=0, pin_memory=False)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import math
|
| 24 |
+
import time
|
| 25 |
+
import os
|
| 26 |
+
import torch
|
| 27 |
+
from torch.optim import AdamW
|
| 28 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 29 |
+
from tqdm.auto import tqdm
|
| 30 |
+
|
| 31 |
+
from config import CFG
|
| 32 |
+
from data_prep import get_dataloaders, get_dataloaders_for_model, get_custom_vlm_dataloader
|
| 33 |
+
from models.blip_tuner import get_blip_model, save_ckpt as blip_save, generate_with_mask
|
| 34 |
+
from models.vit_gpt2_tuner import get_vit_gpt2_model, save_ckpt as vit_gpt2_save
|
| 35 |
+
from models.git_tuner import get_git_model, save_ckpt as git_save
|
| 36 |
+
from models.custom_vlm import CustomVLM, build_char_vocab
|
| 37 |
+
from pycocoevalcap.cider.cider import Cider
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_device():
|
| 41 |
+
if torch.backends.mps.is_available():
|
| 42 |
+
return torch.device("mps")
|
| 43 |
+
elif torch.cuda.is_available():
|
| 44 |
+
return torch.device("cuda")
|
| 45 |
+
return torch.device("cpu")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_output_paths(cfg, model_name: str):
|
| 49 |
+
"""
|
| 50 |
+
Return (latest_dir, best_dir) for a given model.
|
| 51 |
+
Creates directories if they don't exist.
|
| 52 |
+
"""
|
| 53 |
+
base = os.path.join(cfg.output_root, model_name)
|
| 54 |
+
latest = os.path.join(base, "latest")
|
| 55 |
+
best = os.path.join(base, "best")
|
| 56 |
+
os.makedirs(latest, exist_ok=True)
|
| 57 |
+
os.makedirs(best, exist_ok=True)
|
| 58 |
+
return latest, best
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 62 |
+
# Shared Training Loop
|
| 63 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 64 |
+
|
| 65 |
+
def _generate_hf_captions(model, batch, model_name, device,
|
| 66 |
+
processor=None, tokenizer=None):
|
| 67 |
+
"""
|
| 68 |
+
Generate captions for a batch of images using the appropriate HuggingFace model.
|
| 69 |
+
Returns (predictions: list[str], ground_truths: list[str]).
|
| 70 |
+
"""
|
| 71 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 72 |
+
|
| 73 |
+
if model_name == "BLIP":
|
| 74 |
+
B = pixel_values.shape[0]
|
| 75 |
+
mask = torch.ones(B, 197, dtype=torch.long, device=device)
|
| 76 |
+
decoded = generate_with_mask(
|
| 77 |
+
model, processor, device=device,
|
| 78 |
+
pixel_values=pixel_values,
|
| 79 |
+
encoder_attention_mask=mask,
|
| 80 |
+
max_new_tokens=32, num_beams=4,
|
| 81 |
+
)
|
| 82 |
+
preds = decoded # generate_with_mask already returns decoded strings
|
| 83 |
+
labels = batch["labels"].clone()
|
| 84 |
+
gt_texts = processor.batch_decode(labels, skip_special_tokens=True)
|
| 85 |
+
|
| 86 |
+
elif model_name == "VIT_GPT2":
|
| 87 |
+
out = model.generate(
|
| 88 |
+
pixel_values=pixel_values, num_beams=4, max_new_tokens=32,
|
| 89 |
+
)
|
| 90 |
+
preds = [tokenizer.decode(ids, skip_special_tokens=True) for ids in out]
|
| 91 |
+
labels = batch["labels"].clone()
|
| 92 |
+
labels[labels == -100] = tokenizer.pad_token_id
|
| 93 |
+
gt_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
| 94 |
+
|
| 95 |
+
elif model_name == "GIT":
|
| 96 |
+
inputs = {k: v.to(device) for k, v in batch.items()
|
| 97 |
+
if k in ("pixel_values", "input_ids", "attention_mask")}
|
| 98 |
+
out = model.generate(**inputs, num_beams=4, max_new_tokens=32)
|
| 99 |
+
preds = processor.batch_decode(out, skip_special_tokens=True)
|
| 100 |
+
labels = batch["labels"].clone()
|
| 101 |
+
labels[labels == -100] = processor.tokenizer.pad_token_id
|
| 102 |
+
gt_texts = processor.batch_decode(labels, skip_special_tokens=True)
|
| 103 |
+
else:
|
| 104 |
+
return [], []
|
| 105 |
+
|
| 106 |
+
return preds, gt_texts
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def run_training_loop(model, optimizer, scheduler, train_loader, val_loader,
|
| 110 |
+
cfg, save_latest_fn, save_best_fn, model_name,
|
| 111 |
+
processor=None, tokenizer=None):
|
| 112 |
+
"""
|
| 113 |
+
Shared gradient-accumulation training loop for all HuggingFace models.
|
| 114 |
+
|
| 115 |
+
Now includes per-epoch:
|
| 116 |
+
- Validation loss
|
| 117 |
+
- CIDEr scoring via greedy generation
|
| 118 |
+
- CIDEr-based checkpointing (saves best/ based on highest CIDEr)
|
| 119 |
+
"""
|
| 120 |
+
device = get_device()
|
| 121 |
+
model.train()
|
| 122 |
+
global_step = 0
|
| 123 |
+
best_cider = -1.0
|
| 124 |
+
t0 = time.time()
|
| 125 |
+
|
| 126 |
+
for epoch in range(1, cfg.epochs + 1):
|
| 127 |
+
model.train()
|
| 128 |
+
pbar = tqdm(train_loader, desc=f"[{model_name}] Epoch {epoch}/{cfg.epochs}")
|
| 129 |
+
running_loss = 0.0
|
| 130 |
+
epoch_loss_sum = 0.0
|
| 131 |
+
epoch_batches = 0
|
| 132 |
+
optimizer.zero_grad(set_to_none=True)
|
| 133 |
+
|
| 134 |
+
for i, batch in enumerate(pbar, start=1):
|
| 135 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
| 136 |
+
|
| 137 |
+
out = model(**batch)
|
| 138 |
+
loss = out.loss / cfg.grad_accum
|
| 139 |
+
loss.backward()
|
| 140 |
+
running_loss += loss.item()
|
| 141 |
+
epoch_loss_sum += out.loss.item()
|
| 142 |
+
epoch_batches += 1
|
| 143 |
+
|
| 144 |
+
if i % cfg.grad_accum == 0 or i == len(train_loader):
|
| 145 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
|
| 146 |
+
optimizer.step()
|
| 147 |
+
scheduler.step()
|
| 148 |
+
optimizer.zero_grad(set_to_none=True)
|
| 149 |
+
global_step += 1
|
| 150 |
+
|
| 151 |
+
if global_step % cfg.log_every == 0:
|
| 152 |
+
avg = running_loss / cfg.log_every
|
| 153 |
+
running_loss = 0.0
|
| 154 |
+
pbar.set_postfix({"loss": f"{avg:.4f}",
|
| 155 |
+
"lr": f"{scheduler.get_last_lr()[0]:.2e}"})
|
| 156 |
+
|
| 157 |
+
# End of epoch — training metrics
|
| 158 |
+
epoch_avg_loss = epoch_loss_sum / max(epoch_batches, 1)
|
| 159 |
+
print(f"\n📊 Epoch {epoch}/{cfg.epochs} avg loss (Train): {epoch_avg_loss:.4f}")
|
| 160 |
+
|
| 161 |
+
# ── Validation Loop: Loss + CIDEr ────────────────────────────────────
|
| 162 |
+
model.eval()
|
| 163 |
+
val_loss_sum = 0.0
|
| 164 |
+
val_batches = 0
|
| 165 |
+
gts, res = {}, {}
|
| 166 |
+
max_eval_batches = 10
|
| 167 |
+
print(" 🔍 Running Validation (Loss & CIDEr)...")
|
| 168 |
+
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
for i, batch in enumerate(val_loader):
|
| 171 |
+
if i >= max_eval_batches:
|
| 172 |
+
break
|
| 173 |
+
|
| 174 |
+
batch_d = {k: v.to(device) for k, v in batch.items()}
|
| 175 |
+
|
| 176 |
+
# 1. Validation loss
|
| 177 |
+
out = model(**batch_d)
|
| 178 |
+
val_loss_sum += out.loss.item()
|
| 179 |
+
val_batches += 1
|
| 180 |
+
|
| 181 |
+
# 2. Generate captions for CIDEr
|
| 182 |
+
preds, gt_texts = _generate_hf_captions(
|
| 183 |
+
model, batch, model_name, device,
|
| 184 |
+
processor=processor, tokenizer=tokenizer,
|
| 185 |
+
)
|
| 186 |
+
for j, (p, g) in enumerate(zip(preds, gt_texts)):
|
| 187 |
+
k = f"{epoch}_{i}_{j}"
|
| 188 |
+
res[k] = [p]
|
| 189 |
+
gts[k] = [g]
|
| 190 |
+
|
| 191 |
+
val_avg_loss = val_loss_sum / max(val_batches, 1)
|
| 192 |
+
print(f" 📉 Validation Loss: {val_avg_loss:.4f}")
|
| 193 |
+
|
| 194 |
+
# Compute CIDEr
|
| 195 |
+
cider_score = 0.0
|
| 196 |
+
if gts:
|
| 197 |
+
scorer = Cider()
|
| 198 |
+
cider_score, _ = scorer.compute_score(gts, res)
|
| 199 |
+
print(f" 🎯 Validation CIDEr: {cider_score:.4f}")
|
| 200 |
+
|
| 201 |
+
# Save latest checkpoint
|
| 202 |
+
save_latest_fn(step=global_step, epoch=epoch)
|
| 203 |
+
print(f" 💾 Saved → latest/")
|
| 204 |
+
|
| 205 |
+
# Save best based on CIDEr score
|
| 206 |
+
if cider_score > best_cider:
|
| 207 |
+
best_cider = cider_score
|
| 208 |
+
save_best_fn(step=global_step, epoch=epoch)
|
| 209 |
+
print(f" 🏆 New best CIDEr (score={best_cider:.4f}) → best/")
|
| 210 |
+
|
| 211 |
+
elapsed = (time.time() - t0) / 60.0
|
| 212 |
+
print(f"\n✅ {model_name} training complete in {elapsed:.2f} minutes")
|
| 213 |
+
print(f" Best validation CIDEr: {best_cider:.4f}")
|
| 214 |
+
return global_step
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 218 |
+
# Custom VLM Training (projection-only)
|
| 219 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 220 |
+
|
| 221 |
+
def train_custom_vlm(cfg, device):
|
| 222 |
+
print("📖 Loading Shakespeare corpus for character vocabulary...")
|
| 223 |
+
with open(cfg.shakespeare_file, "r", encoding="utf-8") as f:
|
| 224 |
+
text = f.read()
|
| 225 |
+
_, char_to_idx, idx_to_char, vocab_size = build_char_vocab(text)
|
| 226 |
+
print(f"✅ Vocabulary size: {vocab_size} characters")
|
| 227 |
+
|
| 228 |
+
model = CustomVLM(
|
| 229 |
+
vocab_size=vocab_size,
|
| 230 |
+
text_embed_dim=cfg.text_embed_dim,
|
| 231 |
+
n_heads=cfg.n_heads,
|
| 232 |
+
n_layers=cfg.n_layers,
|
| 233 |
+
block_size=cfg.block_size,
|
| 234 |
+
dropout=cfg.dropout,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# ── Load pre-trained Shakespeare decoder weights (CRITICAL) ──────────────
|
| 238 |
+
shakespeare_path = getattr(cfg, "shakespeare_weights_path",
|
| 239 |
+
"./shakespeare_transformer.pt")
|
| 240 |
+
if os.path.exists(shakespeare_path):
|
| 241 |
+
model.load_shakespeare_weights(shakespeare_path)
|
| 242 |
+
print(f"✅ Shakespeare decoder weights loaded from {shakespeare_path}")
|
| 243 |
+
else:
|
| 244 |
+
print(f"⚠️ shakespeare_transformer.pt not found at {shakespeare_path}")
|
| 245 |
+
print(" Training with randomly initialized decoder (significantly worse).")
|
| 246 |
+
|
| 247 |
+
model.unfreeze_decoder()
|
| 248 |
+
model.to(device)
|
| 249 |
+
|
| 250 |
+
n_train = model.trainable_params()
|
| 251 |
+
n_total = sum(p.numel() for p in model.parameters())
|
| 252 |
+
print(f"✅ CustomVLM: {n_train:,} trainable / {n_total:,} total params")
|
| 253 |
+
print(f" (Projection + Decoder trainable — {n_train/n_total*100:.2f}%)")
|
| 254 |
+
|
| 255 |
+
train_loader, val_loader = get_custom_vlm_dataloader(cfg, char_to_idx)
|
| 256 |
+
|
| 257 |
+
# Discriminative learning rates: projection (higher) + decoder (gentler)
|
| 258 |
+
param_groups = model.get_param_groups(
|
| 259 |
+
projection_lr=cfg.lr, # 1e-4
|
| 260 |
+
decoder_lr=cfg.lr * 0.5, # 5e-5
|
| 261 |
+
)
|
| 262 |
+
optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)
|
| 263 |
+
total_steps = math.ceil(len(train_loader) / cfg.grad_accum) * cfg.epochs
|
| 264 |
+
warmup_steps = int(total_steps * cfg.warmup_ratio)
|
| 265 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
|
| 266 |
+
|
| 267 |
+
latest_dir, best_dir = get_output_paths(cfg, "custom_vlm")
|
| 268 |
+
|
| 269 |
+
# Metrics history
|
| 270 |
+
best_cider = -1.0
|
| 271 |
+
cider_scorer = Cider()
|
| 272 |
+
|
| 273 |
+
model.train()
|
| 274 |
+
global_step = 0
|
| 275 |
+
t0 = time.time()
|
| 276 |
+
|
| 277 |
+
for epoch in range(1, cfg.epochs + 1):
|
| 278 |
+
model.train()
|
| 279 |
+
pbar = tqdm(train_loader, desc=f"[CustomVLM] Epoch {epoch}/{cfg.epochs}")
|
| 280 |
+
running_loss = 0.0
|
| 281 |
+
epoch_loss_sum = 0.0
|
| 282 |
+
epoch_batches = 0
|
| 283 |
+
optimizer.zero_grad(set_to_none=True)
|
| 284 |
+
|
| 285 |
+
for i, batch in enumerate(pbar, start=1):
|
| 286 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 287 |
+
text_input_ids = batch["text_input_ids"].to(device)
|
| 288 |
+
text_targets = batch["text_targets"].to(device)
|
| 289 |
+
|
| 290 |
+
_, loss = model(pixel_values, text_input_ids, text_targets)
|
| 291 |
+
(loss / cfg.grad_accum).backward()
|
| 292 |
+
running_loss += loss.item()
|
| 293 |
+
epoch_loss_sum += loss.item()
|
| 294 |
+
epoch_batches += 1
|
| 295 |
+
|
| 296 |
+
if i % cfg.grad_accum == 0 or i == len(train_loader):
|
| 297 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
|
| 298 |
+
optimizer.step()
|
| 299 |
+
scheduler.step()
|
| 300 |
+
optimizer.zero_grad(set_to_none=True)
|
| 301 |
+
global_step += 1
|
| 302 |
+
|
| 303 |
+
if global_step % cfg.log_every == 0:
|
| 304 |
+
avg = running_loss / cfg.log_every
|
| 305 |
+
running_loss = 0.0
|
| 306 |
+
pbar.set_postfix({"loss": f"{avg:.4f}",
|
| 307 |
+
"lr": f"{scheduler.get_last_lr()[0]:.2e}"})
|
| 308 |
+
|
| 309 |
+
# End of epoch metrics
|
| 310 |
+
epoch_avg_loss = epoch_loss_sum / max(epoch_batches, 1)
|
| 311 |
+
print(f"\n📊 Epoch {epoch}/{cfg.epochs} avg loss (Train): {epoch_avg_loss:.4f}")
|
| 312 |
+
|
| 313 |
+
# --- Validation Loop ---
|
| 314 |
+
model.eval()
|
| 315 |
+
val_loss_sum = 0.0
|
| 316 |
+
val_batches = 0
|
| 317 |
+
ref_dict = {}
|
| 318 |
+
hyp_dict = {}
|
| 319 |
+
|
| 320 |
+
# Use a small subset for quick CIDEr eval during training
|
| 321 |
+
max_eval_batches = 10
|
| 322 |
+
print(" 🔍 Running Validation (Loss & CIDEr)...")
|
| 323 |
+
|
| 324 |
+
with torch.no_grad():
|
| 325 |
+
for i, batch in enumerate(val_loader):
|
| 326 |
+
if i >= max_eval_batches:
|
| 327 |
+
break
|
| 328 |
+
|
| 329 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 330 |
+
text_input_ids = batch["text_input_ids"].to(device)
|
| 331 |
+
text_targets = batch["text_targets"].to(device)
|
| 332 |
+
|
| 333 |
+
# 1. Validation Loss
|
| 334 |
+
_, loss = model(pixel_values, text_input_ids, text_targets)
|
| 335 |
+
val_loss_sum += loss.item()
|
| 336 |
+
val_batches += 1
|
| 337 |
+
|
| 338 |
+
# 2. Generation for CIDEr — iterate per sample (generate expects single image)
|
| 339 |
+
B = pixel_values.shape[0]
|
| 340 |
+
for b in range(B):
|
| 341 |
+
pv_single = pixel_values[b:b+1]
|
| 342 |
+
gen_caption = model.generate(pv_single, char_to_idx, idx_to_char, max_new_tokens=40)
|
| 343 |
+
|
| 344 |
+
tgt_cpu = text_targets[b].cpu().tolist()
|
| 345 |
+
true_str = "".join([idx_to_char.get(c, "") for c in tgt_cpu if c > 0])
|
| 346 |
+
|
| 347 |
+
img_id = f"{epoch}_{i}_{b}"
|
| 348 |
+
ref_dict[img_id] = [true_str]
|
| 349 |
+
hyp_dict[img_id] = [gen_caption]
|
| 350 |
+
|
| 351 |
+
val_avg_loss = val_loss_sum / max(val_batches, 1)
|
| 352 |
+
print(f" 📉 Validation Loss: {val_avg_loss:.4f}")
|
| 353 |
+
|
| 354 |
+
# Calculate CIDEr
|
| 355 |
+
try:
|
| 356 |
+
cider_score, _ = cider_scorer.compute_score(ref_dict, hyp_dict)
|
| 357 |
+
except Exception:
|
| 358 |
+
cider_score = 0.0
|
| 359 |
+
|
| 360 |
+
print(f" 🎯 Validation CIDEr: {cider_score:.4f}")
|
| 361 |
+
|
| 362 |
+
# Save latest (always)
|
| 363 |
+
_save_custom(model, char_to_idx, idx_to_char, cfg,
|
| 364 |
+
global_step, epoch, latest_dir)
|
| 365 |
+
print(f" 💾 Saved → {latest_dir}")
|
| 366 |
+
|
| 367 |
+
# Save best (based on highest CIDEr score)
|
| 368 |
+
if cider_score >= best_cider:
|
| 369 |
+
best_cider = cider_score
|
| 370 |
+
_save_custom(model, char_to_idx, idx_to_char, cfg,
|
| 371 |
+
global_step, epoch, best_dir)
|
| 372 |
+
print(f" 🏆 New best CIDEr (score={best_cider:.4f}) → {best_dir}")
|
| 373 |
+
|
| 374 |
+
elapsed = (time.time() - t0) / 60.0
|
| 375 |
+
print(f"\n✅ CustomVLM training complete in {elapsed:.2f} minutes")
|
| 376 |
+
print(f" Best validation CIDEr: {best_cider:.4f}")
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def _save_custom(model, char_to_idx, idx_to_char, cfg, step, epoch, save_dir):
|
| 380 |
+
"""Save CustomVLM checkpoint to the given directory (overwrites previous)."""
|
| 381 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 382 |
+
torch.save({
|
| 383 |
+
"model_state": model.state_dict(),
|
| 384 |
+
"char_to_idx": char_to_idx,
|
| 385 |
+
"idx_to_char": idx_to_char,
|
| 386 |
+
"config": {
|
| 387 |
+
"block_size": cfg.block_size,
|
| 388 |
+
"text_embed_dim": cfg.text_embed_dim,
|
| 389 |
+
"n_heads": cfg.n_heads,
|
| 390 |
+
"n_layers": cfg.n_layers,
|
| 391 |
+
"vocab_size": len(char_to_idx),
|
| 392 |
+
},
|
| 393 |
+
"step": step, "epoch": epoch,
|
| 394 |
+
}, os.path.join(save_dir, "custom_vlm.pt"))
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 398 |
+
# Main
|
| 399 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 400 |
+
|
| 401 |
+
def main():
|
| 402 |
+
parser = argparse.ArgumentParser(description="Train VLM — BLIP | ViT-GPT2 | GIT | Custom")
|
| 403 |
+
parser.add_argument(
|
| 404 |
+
"--model", type=str, default="blip",
|
| 405 |
+
choices=["blip", "vit_gpt2", "git", "custom"],
|
| 406 |
+
help="Which architecture to train",
|
| 407 |
+
)
|
| 408 |
+
args = parser.parse_args()
|
| 409 |
+
|
| 410 |
+
cfg = CFG.load_for_model(args.model)
|
| 411 |
+
device = get_device()
|
| 412 |
+
print(f"✅ Device: {device}")
|
| 413 |
+
print(f"✅ Config: {args.model} | epochs={cfg.epochs} | lr={cfg.lr} | "
|
| 414 |
+
f"batch_size={cfg.batch_size} | max_target_len={cfg.max_target_len}")
|
| 415 |
+
print(f"✅ Output: {cfg.output_root}/{args.model}/")
|
| 416 |
+
|
| 417 |
+
# ── Custom VLM has its own dedicated loop ──────────────────────────────
|
| 418 |
+
if args.model == "custom":
|
| 419 |
+
train_custom_vlm(cfg, device)
|
| 420 |
+
return
|
| 421 |
+
|
| 422 |
+
# ── HuggingFace Models ─────────────────────────────────────────────────
|
| 423 |
+
latest_dir, best_dir = get_output_paths(cfg, args.model)
|
| 424 |
+
|
| 425 |
+
processor = None
|
| 426 |
+
tokenizer = None
|
| 427 |
+
|
| 428 |
+
if args.model == "blip":
|
| 429 |
+
model, processor = get_blip_model(cfg, device)
|
| 430 |
+
train_loader, val_loader = get_dataloaders(cfg, processor)
|
| 431 |
+
|
| 432 |
+
def save_latest_fn(step, epoch):
|
| 433 |
+
blip_save(model, processor, None, None, step, epoch, cfg.__dict__, latest_dir)
|
| 434 |
+
|
| 435 |
+
def save_best_fn(step, epoch):
|
| 436 |
+
blip_save(model, processor, None, None, step, epoch, cfg.__dict__, best_dir)
|
| 437 |
+
|
| 438 |
+
elif args.model == "vit_gpt2":
|
| 439 |
+
model, processor, tokenizer = get_vit_gpt2_model(cfg, device)
|
| 440 |
+
train_loader, val_loader = get_dataloaders_for_model(cfg, "vit_gpt2", processor, tokenizer)
|
| 441 |
+
|
| 442 |
+
def save_latest_fn(step, epoch):
|
| 443 |
+
vit_gpt2_save(model, processor, tokenizer, None, None, step, epoch, cfg.__dict__, latest_dir)
|
| 444 |
+
|
| 445 |
+
def save_best_fn(step, epoch):
|
| 446 |
+
vit_gpt2_save(model, processor, tokenizer, None, None, step, epoch, cfg.__dict__, best_dir)
|
| 447 |
+
|
| 448 |
+
elif args.model == "git":
|
| 449 |
+
model, processor = get_git_model(cfg, device)
|
| 450 |
+
train_loader, val_loader = get_dataloaders_for_model(cfg, "git", processor)
|
| 451 |
+
|
| 452 |
+
def save_latest_fn(step, epoch):
|
| 453 |
+
git_save(model, processor, None, None, step, epoch, cfg.__dict__, latest_dir)
|
| 454 |
+
|
| 455 |
+
def save_best_fn(step, epoch):
|
| 456 |
+
git_save(model, processor, None, None, step, epoch, cfg.__dict__, best_dir)
|
| 457 |
+
|
| 458 |
+
optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
|
| 459 |
+
total_steps = math.ceil(len(train_loader) / cfg.grad_accum) * cfg.epochs
|
| 460 |
+
warmup_steps = int(total_steps * cfg.warmup_ratio)
|
| 461 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
|
| 462 |
+
print(f"✅ Update steps: {total_steps} | Warmup: {warmup_steps}")
|
| 463 |
+
|
| 464 |
+
run_training_loop(model, optimizer, scheduler, train_loader, val_loader, cfg,
|
| 465 |
+
save_latest_fn=save_latest_fn,
|
| 466 |
+
save_best_fn=save_best_fn,
|
| 467 |
+
model_name=args.model.upper(),
|
| 468 |
+
processor=processor, tokenizer=tokenizer)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
if __name__ == "__main__":
|
| 472 |
+
main()
|
transformer2.ipynb
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 112,
|
| 6 |
+
"id": "5f1bb753",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"with open(\"input.txt\", \"r\") as f:\n",
|
| 11 |
+
" text = f.read()"
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": 113,
|
| 17 |
+
"id": "9cf7e7ac",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [
|
| 20 |
+
{
|
| 21 |
+
"name": "stdout",
|
| 22 |
+
"output_type": "stream",
|
| 23 |
+
"text": [
|
| 24 |
+
"Length of text: 1115394 characters\n",
|
| 25 |
+
"\n",
|
| 26 |
+
" !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
|
| 27 |
+
"Vocab size: 65\n"
|
| 28 |
+
]
|
| 29 |
+
}
|
| 30 |
+
],
|
| 31 |
+
"source": [
|
| 32 |
+
"length = len(text)\n",
|
| 33 |
+
"print(f\"Length of text: {length} characters\")\n",
|
| 34 |
+
"char = sorted(list(set(text)))\n",
|
| 35 |
+
"vocab_size = len(char)\n",
|
| 36 |
+
"print(\"\".join(char))\n",
|
| 37 |
+
"print(f\"Vocab size: {vocab_size}\")"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": 114,
|
| 43 |
+
"id": "1b910dc7",
|
| 44 |
+
"metadata": {},
|
| 45 |
+
"outputs": [
|
| 46 |
+
{
|
| 47 |
+
"name": "stdout",
|
| 48 |
+
"output_type": "stream",
|
| 49 |
+
"text": [
|
| 50 |
+
"[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]\n",
|
| 51 |
+
"hello world\n"
|
| 52 |
+
]
|
| 53 |
+
}
|
| 54 |
+
],
|
| 55 |
+
"source": [
|
| 56 |
+
"stoi = {ch:i for i,ch in enumerate(char)}\n",
|
| 57 |
+
"itos = {i:ch for i,ch in enumerate(char)}\n",
|
| 58 |
+
"encode = lambda s: [stoi[c] for c in s]\n",
|
| 59 |
+
"decode = lambda l: \"\".join([itos[i] for i in l])\n",
|
| 60 |
+
"print(encode(\"hello world\"))\n",
|
| 61 |
+
"print(decode(encode(\"hello world\"))) # note this is one of the simplest possible tokenizers, it just maps each character to an integer. everyone has their own tokenizer like google use sentencepiece, openai use bpe, etc. we will build our own tokenizer in the next notebook."
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": 115,
|
| 67 |
+
"id": "3d287813",
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"outputs": [
|
| 70 |
+
{
|
| 71 |
+
"data": {
|
| 72 |
+
"text/plain": [
|
| 73 |
+
"<torch._C.Generator at 0x11113bdd0>"
|
| 74 |
+
]
|
| 75 |
+
},
|
| 76 |
+
"execution_count": 115,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"output_type": "execute_result"
|
| 79 |
+
}
|
| 80 |
+
],
|
| 81 |
+
"source": [
|
| 82 |
+
"import torch\n",
|
| 83 |
+
"import torch.nn as nn\n",
|
| 84 |
+
"import torch.nn.functional as F\n",
|
| 85 |
+
"torch.manual_seed(42)"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": 116,
|
| 91 |
+
"id": "4786dcce",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"outputs": [
|
| 94 |
+
{
|
| 95 |
+
"name": "stdout",
|
| 96 |
+
"output_type": "stream",
|
| 97 |
+
"text": [
|
| 98 |
+
"torch.Size([1115394]) torch.int64\n",
|
| 99 |
+
"tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
|
| 100 |
+
" 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
|
| 101 |
+
" 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
|
| 102 |
+
" 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
|
| 103 |
+
" 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
|
| 104 |
+
" 58, 47, 64, 43, 52, 10, 0, 37, 53, 59])\n"
|
| 105 |
+
]
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"source": [
|
| 109 |
+
"data = torch.tensor(encode(text), dtype=torch.long)\n",
|
| 110 |
+
"print(data.shape, data.dtype)\n",
|
| 111 |
+
"print(data[:100])"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "code",
|
| 116 |
+
"execution_count": 117,
|
| 117 |
+
"id": "ee9c3b71",
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"outputs": [
|
| 120 |
+
{
|
| 121 |
+
"name": "stdout",
|
| 122 |
+
"output_type": "stream",
|
| 123 |
+
"text": [
|
| 124 |
+
"torch.Size([1003854]) torch.Size([111540])\n"
|
| 125 |
+
]
|
| 126 |
+
}
|
| 127 |
+
],
|
| 128 |
+
"source": [
|
| 129 |
+
"n = int(0.9*len(data))\n",
|
| 130 |
+
"train_data = data[:n]\n",
|
| 131 |
+
"val_data = data[n:]\n",
|
| 132 |
+
"print(train_data.shape, val_data.shape)"
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"cell_type": "code",
|
| 137 |
+
"execution_count": 118,
|
| 138 |
+
"id": "14d2fe85",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [
|
| 141 |
+
{
|
| 142 |
+
"name": "stdout",
|
| 143 |
+
"output_type": "stream",
|
| 144 |
+
"text": [
|
| 145 |
+
"tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])\n"
|
| 146 |
+
]
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
"source": [
|
| 150 |
+
"block_size = 8\n",
|
| 151 |
+
"train_data[:block_size+1] # we will use the first 8 characters to predict\n",
|
| 152 |
+
"print(train_data[:block_size+1])"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": 119,
|
| 158 |
+
"id": "a690a090",
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"outputs": [
|
| 161 |
+
{
|
| 162 |
+
"name": "stdout",
|
| 163 |
+
"output_type": "stream",
|
| 164 |
+
"text": [
|
| 165 |
+
"using mps device\n"
|
| 166 |
+
]
|
| 167 |
+
}
|
| 168 |
+
],
|
| 169 |
+
"source": [
|
| 170 |
+
"#use mps as i am using the mac with m4 \n",
|
| 171 |
+
"batch_size = 64 # how many independent sequences will we process in parallel?\n",
|
| 172 |
+
"block_size = 256 # what is the maximum context length for predictions?\n",
|
| 173 |
+
"n_embeed = 384 \n",
|
| 174 |
+
"max_iters = 20000\n",
|
| 175 |
+
"eval_iters = 2000 \n",
|
| 176 |
+
"lr_rate = 2e-4\n",
|
| 177 |
+
"dropout = 0.2\n",
|
| 178 |
+
"n_layer = 8\n",
|
| 179 |
+
"n_head = 8\n",
|
| 180 |
+
"device = \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
|
| 181 |
+
"print(f\"using {device} device\")"
|
| 182 |
+
]
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"cell_type": "code",
|
| 186 |
+
"execution_count": 120,
|
| 187 |
+
"id": "d90a7d94",
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [
|
| 190 |
+
{
|
| 191 |
+
"name": "stdout",
|
| 192 |
+
"output_type": "stream",
|
| 193 |
+
"text": [
|
| 194 |
+
"inputs:\n",
|
| 195 |
+
"torch.Size([64, 256])\n",
|
| 196 |
+
"tensor([[ 0, 26, 53, ..., 56, 43, 47],\n",
|
| 197 |
+
" [60, 43, 56, ..., 56, 1, 41],\n",
|
| 198 |
+
" [26, 21, 33, ..., 26, 21, 13],\n",
|
| 199 |
+
" ...,\n",
|
| 200 |
+
" [ 5, 57, 1, ..., 1, 35, 47],\n",
|
| 201 |
+
" [56, 53, 53, ..., 59, 50, 42],\n",
|
| 202 |
+
" [42, 47, 56, ..., 39, 56, 1]], device='mps:0')\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"targets:\n",
|
| 205 |
+
"torch.Size([64, 256])\n",
|
| 206 |
+
"tensor([[26, 53, 58, ..., 43, 47, 45],\n",
|
| 207 |
+
" [43, 56, 1, ..., 1, 41, 53],\n",
|
| 208 |
+
" [21, 33, 31, ..., 21, 13, 10],\n",
|
| 209 |
+
" ...,\n",
|
| 210 |
+
" [57, 1, 52, ..., 35, 47, 50],\n",
|
| 211 |
+
" [53, 53, 58, ..., 50, 42, 1],\n",
|
| 212 |
+
" [47, 56, 43, ..., 56, 1, 51]], device='mps:0')\n"
|
| 213 |
+
]
|
| 214 |
+
}
|
| 215 |
+
],
|
| 216 |
+
"source": [
|
| 217 |
+
"torch.manual_seed(1337)\n",
|
| 218 |
+
"def get_batch(split):\n",
|
| 219 |
+
" data = train_data if split == 'train' else val_data\n",
|
| 220 |
+
" ix = torch.randint(len(data) - block_size, (batch_size,))\n",
|
| 221 |
+
" x = torch.stack([data[i:i+block_size] for i in ix])\n",
|
| 222 |
+
" y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
|
| 223 |
+
" x, y = x.to(device), y.to(device)\n",
|
| 224 |
+
" return x, y\n",
|
| 225 |
+
"xb, yb = get_batch('train')\n",
|
| 226 |
+
"print(\"inputs:\")\n",
|
| 227 |
+
"print(xb.shape)\n",
|
| 228 |
+
"print(xb)\n",
|
| 229 |
+
"print(\"\\ntargets:\")\n",
|
| 230 |
+
"print(yb.shape)\n",
|
| 231 |
+
"print(yb)"
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"cell_type": "code",
|
| 236 |
+
"execution_count": 121,
|
| 237 |
+
"id": "27573f3f",
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"outputs": [],
|
| 240 |
+
"source": [
|
| 241 |
+
"class Head(torch.nn.Module):\n",
|
| 242 |
+
" def __init__(self, head_size):\n",
|
| 243 |
+
" super().__init__()\n",
|
| 244 |
+
" self.head_size = head_size\n",
|
| 245 |
+
" self.key = nn.Linear(n_embeed, head_size, bias=False)\n",
|
| 246 |
+
" self.query = nn.Linear(n_embeed, head_size, bias=False)\n",
|
| 247 |
+
" self.value = nn.Linear(n_embeed, head_size, bias=False)\n",
|
| 248 |
+
" self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
|
| 249 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
| 250 |
+
"\n",
|
| 251 |
+
" def forward(self, x):\n",
|
| 252 |
+
" B,T,C = x.shape\n",
|
| 253 |
+
" k = self.key(x) # (B,T,16)\n",
|
| 254 |
+
" q = self.query(x) # (B,T,16)\n",
|
| 255 |
+
" v = self.value(x) # (B,T,16)\n",
|
| 256 |
+
" weights = q @ k.transpose(-2, -1) * (self.head_size ** -0.5) # (B,T,16) @ (B,16,T) -> (B,T,T)\n",
|
| 257 |
+
" tril = torch.tril(torch.ones(T, T , device=x.device))\n",
|
| 258 |
+
" weights = weights.masked_fill(tril == 0, float('-inf')) # when we talk about the encoder transformer we remove this bcz we want to attend to all the tokens in the input sequence, but in the decoder transformer we want to attend only to the previous tokens in the output sequence, so we use this mask to prevent the model from attending to future tokens.\n",
|
| 259 |
+
" weights = torch.softmax(weights, dim=-1)\n",
|
| 260 |
+
" weights = self.dropout(weights)\n",
|
| 261 |
+
" out = weights @ v # (B,T,T) @ (B,T,C) -> (B,T,C)\n",
|
| 262 |
+
" return out\n",
|
| 263 |
+
" "
|
| 264 |
+
]
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"cell_type": "code",
|
| 268 |
+
"execution_count": 122,
|
| 269 |
+
"id": "a776b854",
|
| 270 |
+
"metadata": {},
|
| 271 |
+
"outputs": [],
|
| 272 |
+
"source": [
|
| 273 |
+
"class MultiHeadAttention(torch.nn.Module):\n",
|
| 274 |
+
" def __init__(self, num_heads, head_size):\n",
|
| 275 |
+
" super().__init__()\n",
|
| 276 |
+
" self.heads = torch.nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
|
| 277 |
+
" self.proj = torch.nn.Linear(n_embeed, n_embeed)\n",
|
| 278 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
| 279 |
+
" def forward(self, x):\n",
|
| 280 |
+
" out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
|
| 281 |
+
" out = self.proj(out)\n",
|
| 282 |
+
" return out"
|
| 283 |
+
]
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"cell_type": "code",
|
| 287 |
+
"execution_count": 123,
|
| 288 |
+
"id": "da0d9201",
|
| 289 |
+
"metadata": {},
|
| 290 |
+
"outputs": [],
|
| 291 |
+
"source": [
|
| 292 |
+
"class feedforward(torch.nn.Module):\n",
|
| 293 |
+
" def __init__(self, n_embeed):\n",
|
| 294 |
+
" super().__init__()\n",
|
| 295 |
+
" self.net = torch.nn.Sequential(\n",
|
| 296 |
+
" torch.nn.Linear(n_embeed, 4*n_embeed), #according to paper there is the multiplier of 4 in the hidden layers \n",
|
| 297 |
+
" torch.nn.ReLU(),\n",
|
| 298 |
+
" torch.nn.Linear(4*n_embeed, n_embeed),\n",
|
| 299 |
+
" nn.Dropout(dropout)\n",
|
| 300 |
+
" )\n",
|
| 301 |
+
" def forward(self, x):\n",
|
| 302 |
+
" return self.net(x)"
|
| 303 |
+
]
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"cell_type": "code",
|
| 307 |
+
"execution_count": 124,
|
| 308 |
+
"id": "1b2fc012",
|
| 309 |
+
"metadata": {},
|
| 310 |
+
"outputs": [],
|
| 311 |
+
"source": [
|
| 312 |
+
"class Block(torch.nn.Module):\n",
|
| 313 |
+
" def __init__(self, n_embeed , n_head):\n",
|
| 314 |
+
" super().__init__()\n",
|
| 315 |
+
" head_size = n_embeed // n_head \n",
|
| 316 |
+
" self.sa_head = MultiHeadAttention(num_heads=n_head, head_size=head_size)\n",
|
| 317 |
+
" self.ffwd = feedforward(n_embeed)\n",
|
| 318 |
+
" self.ln1 = nn.LayerNorm(n_embeed)\n",
|
| 319 |
+
" self.ln2 = nn.LayerNorm(n_embeed)\n",
|
| 320 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
| 321 |
+
" def forward(self, x):\n",
|
| 322 |
+
" x = x+self.sa_head(self.ln1(x))#this is slightly deviation from the original paper as we are passing the layer norem before the multi head attention and feedforward, but it is a common practice to do so, and it works better than the original paper.\n",
|
| 323 |
+
" x = x+self.ffwd(self.ln2(x))\n",
|
| 324 |
+
" return x"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"cell_type": "code",
|
| 329 |
+
"execution_count": 125,
|
| 330 |
+
"id": "7a3053a0",
|
| 331 |
+
"metadata": {},
|
| 332 |
+
"outputs": [],
|
| 333 |
+
"source": [
|
| 334 |
+
"class BigramLanguageModel(torch.nn.Module):\n",
|
| 335 |
+
" def __init__(self, vocab_size, n_embeed):\n",
|
| 336 |
+
" super().__init__()\n",
|
| 337 |
+
" self.token_embedding_table = torch.nn.Embedding(vocab_size, n_embeed) \n",
|
| 338 |
+
" self.position_embedding_table = torch.nn.Embedding(block_size, n_embeed)\n",
|
| 339 |
+
" #not giving the good enhaced result when we try multiple Block() , deep neural net suffer from the optimisation issue \n",
|
| 340 |
+
" # self.blocks = nn.Sequential(\n",
|
| 341 |
+
" # Block(n_embeed, n_head=4),\n",
|
| 342 |
+
" # Block(n_embeed, n_head=4),\n",
|
| 343 |
+
" # Block(n_embeed, n_head=4),\n",
|
| 344 |
+
" # nn.LayerNorm(n_embeed),\n",
|
| 345 |
+
" # )\n",
|
| 346 |
+
" self.blocks = nn.Sequential(*[Block(n_embeed, n_head) for _ in range(n_layer)])\n",
|
| 347 |
+
" self.ln_f = nn.LayerNorm(n_embeed)\n",
|
| 348 |
+
" self.lm_head = torch.nn.Linear(n_embeed, vocab_size)\n",
|
| 349 |
+
" def forward(self, idx, targets=None):\n",
|
| 350 |
+
" B,T = idx.shape\n",
|
| 351 |
+
" # idx and targets are both (B,T) tensor of integers\n",
|
| 352 |
+
" token_emb = self.token_embedding_table(idx) # (B,T,C)\n",
|
| 353 |
+
" pos_emb = self.position_embedding_table(torch.arange(idx.shape[1], device=idx.device)) # (T,C)\n",
|
| 354 |
+
" x = token_emb + pos_emb # (B,T,C)\n",
|
| 355 |
+
" x = self.blocks(x) # (B,T,C)\n",
|
| 356 |
+
" x = self.ln_f(x) # (B,T,C)\n",
|
| 357 |
+
" logits = self.lm_head(x) # (B,T,vocab_size)\n",
|
| 358 |
+
" if targets is None:\n",
|
| 359 |
+
" loss = None\n",
|
| 360 |
+
" else:\n",
|
| 361 |
+
" B,T,C = logits.shape\n",
|
| 362 |
+
" logits = logits.view(B*T, C)\n",
|
| 363 |
+
" targets = targets.view(B*T)\n",
|
| 364 |
+
" loss = F.cross_entropy(logits, targets)\n",
|
| 365 |
+
" return logits, loss\n",
|
| 366 |
+
" \n",
|
| 367 |
+
" def generate(self, idx, max_new_tokens):\n",
|
| 368 |
+
" # idx is (B,T) array of indices in the current context\n",
|
| 369 |
+
" for _ in range(max_new_tokens):\n",
|
| 370 |
+
" idx_cond = idx[:, -block_size:] # crop idx to the last block_size tokens\n",
|
| 371 |
+
" logits, loss = self(idx_cond)\n",
|
| 372 |
+
" logits = logits[:, -1, :] # becomes (B,C) , as we only want to provide the last character as the input to predict the next character\n",
|
| 373 |
+
" probs = F.softmax(logits, dim=-1) # (B,C)\n",
|
| 374 |
+
" idx_next = torch.multinomial(probs, num_samples=1) # (B,1)\n",
|
| 375 |
+
" idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)\n",
|
| 376 |
+
" return idx"
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
"cell_type": "code",
|
| 381 |
+
"execution_count": null,
|
| 382 |
+
"id": "67e96f0b",
|
| 383 |
+
"metadata": {},
|
| 384 |
+
"outputs": [],
|
| 385 |
+
"source": []
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"cell_type": "code",
|
| 389 |
+
"execution_count": 126,
|
| 390 |
+
"id": "0e9d66e8",
|
| 391 |
+
"metadata": {},
|
| 392 |
+
"outputs": [
|
| 393 |
+
{
|
| 394 |
+
"name": "stdout",
|
| 395 |
+
"output_type": "stream",
|
| 396 |
+
"text": [
|
| 397 |
+
"logits shape: torch.Size([16384, 65])\n",
|
| 398 |
+
"loss: 4.277037620544434\n",
|
| 399 |
+
"\n",
|
| 400 |
+
"tRNt'OUWzNdaNv;DZ!HWJxsg-rG$l.\n",
|
| 401 |
+
"VXx;h&CEqoyJOlF.DmdMw;u;cjEIgcOQOID;$wig.tRIgazPSVyRpKBE-3UQBdJ'AIIxX\n"
|
| 402 |
+
]
|
| 403 |
+
}
|
| 404 |
+
],
|
| 405 |
+
"source": [
|
| 406 |
+
"model = BigramLanguageModel(vocab_size, n_embeed)\n",
|
| 407 |
+
"model = model.to(device)\n",
|
| 408 |
+
"logits, loss = model(xb, yb)\n",
|
| 409 |
+
"print(\"logits shape:\", logits.shape)\n",
|
| 410 |
+
"print(\"loss:\", loss.item())\n",
|
| 411 |
+
"print(decode(model.generate(idx=torch.zeros((1,1), dtype=torch.long, device=device), max_new_tokens=100)[0].tolist()))"
|
| 412 |
+
]
|
| 413 |
+
},
|
| 414 |
+
{
|
| 415 |
+
"cell_type": "code",
|
| 416 |
+
"execution_count": 127,
|
| 417 |
+
"id": "1da9dd4f",
|
| 418 |
+
"metadata": {},
|
| 419 |
+
"outputs": [],
|
| 420 |
+
"source": [
|
| 421 |
+
"@torch.no_grad()\n",
|
| 422 |
+
"def estimate_loss():\n",
|
| 423 |
+
" out = {}\n",
|
| 424 |
+
" model.eval()\n",
|
| 425 |
+
" for split in ['train', 'val']:\n",
|
| 426 |
+
" losses = torch.zeros(eval_iters)\n",
|
| 427 |
+
" for k in range(eval_iters):\n",
|
| 428 |
+
" X, Y = get_batch(split)\n",
|
| 429 |
+
" logits, loss = model(X, Y)\n",
|
| 430 |
+
" losses[k] = loss.item()\n",
|
| 431 |
+
" out[split] = losses.mean()\n",
|
| 432 |
+
" model.train()\n",
|
| 433 |
+
" return out"
|
| 434 |
+
]
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"cell_type": "code",
|
| 438 |
+
"execution_count": null,
|
| 439 |
+
"id": "1e3fb308",
|
| 440 |
+
"metadata": {},
|
| 441 |
+
"outputs": [
|
| 442 |
+
{
|
| 443 |
+
"name": "stdout",
|
| 444 |
+
"output_type": "stream",
|
| 445 |
+
"text": [
|
| 446 |
+
"step 0: train loss 4.2785, val loss 4.2821\n"
|
| 447 |
+
]
|
| 448 |
+
}
|
| 449 |
+
],
|
| 450 |
+
"source": [
|
| 451 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=lr_rate)\n",
|
| 452 |
+
"for steps in range(max_iters):\n",
|
| 453 |
+
" if steps % eval_iters == 0:\n",
|
| 454 |
+
" losses = estimate_loss()\n",
|
| 455 |
+
" print(f\"step {steps}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
|
| 456 |
+
"\n",
|
| 457 |
+
" xb, yb = get_batch('train')\n",
|
| 458 |
+
" logits, loss = model(xb, yb)\n",
|
| 459 |
+
" optimizer.zero_grad(set_to_none=True)\n",
|
| 460 |
+
" loss.backward()\n",
|
| 461 |
+
" optimizer.step()\n",
|
| 462 |
+
"\n"
|
| 463 |
+
]
|
| 464 |
+
},
|
| 465 |
+
{
|
| 466 |
+
"cell_type": "code",
|
| 467 |
+
"execution_count": null,
|
| 468 |
+
"id": "9490a27b",
|
| 469 |
+
"metadata": {},
|
| 470 |
+
"outputs": [
|
| 471 |
+
{
|
| 472 |
+
"name": "stdout",
|
| 473 |
+
"output_type": "stream",
|
| 474 |
+
"text": [
|
| 475 |
+
"\n",
|
| 476 |
+
"\n",
|
| 477 |
+
"DUKE VINCENTIO:\n",
|
| 478 |
+
"Stand brother, sir, here it, uncle he got.\n",
|
| 479 |
+
"\n",
|
| 480 |
+
"VIRGILIA:\n",
|
| 481 |
+
"A dog of the yousician, let your good brother, sister,\n",
|
| 482 |
+
"nor it to die.\n",
|
| 483 |
+
"\n",
|
| 484 |
+
"VOLUMNIA:\n",
|
| 485 |
+
"She is in the mar, and the matter:\n",
|
| 486 |
+
"there is! What say you, Juliet alone and bird.\n",
|
| 487 |
+
"Is thy life?\n",
|
| 488 |
+
"\n",
|
| 489 |
+
"JULIET:\n",
|
| 490 |
+
"Being a child! prompt fear: speak, and look fellow good?\n",
|
| 491 |
+
"\n",
|
| 492 |
+
"FLORIZEL:\n",
|
| 493 |
+
"And rumour, by my man's tooth made.\n",
|
| 494 |
+
"\n",
|
| 495 |
+
"JULIET:\n",
|
| 496 |
+
"Ay, if you doth make leave your retires,\n",
|
| 497 |
+
"A mother tempt my todder dial should have\n",
|
| 498 |
+
"So dear and let me 'gainst words out again,\n",
|
| 499 |
+
"Savest with honour's princes to hear throught of,\n",
|
| 500 |
+
"His perdom of preserve\n",
|
| 501 |
+
"Is posterity and secut\n",
|
| 502 |
+
"No god costerb: shall be more the Capitol,\n",
|
| 503 |
+
"But court did this hoursest, do begg them buried.\n",
|
| 504 |
+
"His apple and dreams on daughter, and we will,\n",
|
| 505 |
+
"He were laddy's wounds. O mother!\n",
|
| 506 |
+
"Dread!\n",
|
| 507 |
+
"In it the whitest through thee grief: why, general,\n",
|
| 508 |
+
"My heart play'd many fellows upon him.\n",
|
| 509 |
+
"\n",
|
| 510 |
+
"FRIAR LAURENCE:\n",
|
| 511 |
+
"For traitor the mind: what the journey, rise!\n",
|
| 512 |
+
"I serve, or I know the senate, and let my indeed\n",
|
| 513 |
+
"Will on brave it so lone\n"
|
| 514 |
+
]
|
| 515 |
+
}
|
| 516 |
+
],
|
| 517 |
+
"source": [
|
| 518 |
+
"print(decode(model.generate(idx=torch.zeros((1,1), dtype=torch.long, device=device), max_new_tokens=1000)[0].tolist()))"
|
| 519 |
+
]
|
| 520 |
+
},
|
| 521 |
+
{
|
| 522 |
+
"cell_type": "code",
|
| 523 |
+
"execution_count": null,
|
| 524 |
+
"id": "d717cdc1",
|
| 525 |
+
"metadata": {},
|
| 526 |
+
"outputs": [
|
| 527 |
+
{
|
| 528 |
+
"name": "stdout",
|
| 529 |
+
"output_type": "stream",
|
| 530 |
+
"text": [
|
| 531 |
+
"10.788929 M parameters\n"
|
| 532 |
+
]
|
| 533 |
+
}
|
| 534 |
+
],
|
| 535 |
+
"source": [
|
| 536 |
+
"print(sum(p.numel() for p in model.parameters())/1e6, \"M parameters\")"
|
| 537 |
+
]
|
| 538 |
+
},
|
| 539 |
+
{
|
| 540 |
+
"cell_type": "code",
|
| 541 |
+
"execution_count": null,
|
| 542 |
+
"id": "58991844",
|
| 543 |
+
"metadata": {},
|
| 544 |
+
"outputs": [],
|
| 545 |
+
"source": [
|
| 546 |
+
"import torch\n",
|
| 547 |
+
"torch.save(model.state_dict(), \"shakespeare_transformer.pt\")"
|
| 548 |
+
]
|
| 549 |
+
},
|
| 550 |
+
{
|
| 551 |
+
"cell_type": "code",
|
| 552 |
+
"execution_count": null,
|
| 553 |
+
"id": "935f7d0e",
|
| 554 |
+
"metadata": {},
|
| 555 |
+
"outputs": [],
|
| 556 |
+
"source": []
|
| 557 |
+
}
|
| 558 |
+
],
|
| 559 |
+
"metadata": {
|
| 560 |
+
"kernelspec": {
|
| 561 |
+
"display_name": "Python 3",
|
| 562 |
+
"language": "python",
|
| 563 |
+
"name": "python3"
|
| 564 |
+
},
|
| 565 |
+
"language_info": {
|
| 566 |
+
"codemirror_mode": {
|
| 567 |
+
"name": "ipython",
|
| 568 |
+
"version": 3
|
| 569 |
+
},
|
| 570 |
+
"file_extension": ".py",
|
| 571 |
+
"mimetype": "text/x-python",
|
| 572 |
+
"name": "python",
|
| 573 |
+
"nbconvert_exporter": "python",
|
| 574 |
+
"pygments_lexer": "ipython3",
|
| 575 |
+
"version": "3.14.2"
|
| 576 |
+
}
|
| 577 |
+
},
|
| 578 |
+
"nbformat": 4,
|
| 579 |
+
"nbformat_minor": 5
|
| 580 |
+
}
|
transformer_base.ipynb
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"id": "193c3159",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"with open(\"input.txt\", \"r\") as f:\n",
|
| 11 |
+
" text = f.read()"
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": 3,
|
| 17 |
+
"id": "e557cb70",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [
|
| 20 |
+
{
|
| 21 |
+
"name": "stdout",
|
| 22 |
+
"output_type": "stream",
|
| 23 |
+
"text": [
|
| 24 |
+
"Length of text: 1115394 characters\n"
|
| 25 |
+
]
|
| 26 |
+
}
|
| 27 |
+
],
|
| 28 |
+
"source": [
|
| 29 |
+
"length = len(text)\n",
|
| 30 |
+
"print(f\"Length of text: {length} characters\")"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": null,
|
| 36 |
+
"id": "750587a9",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"print(text[:500]) "
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "code",
|
| 45 |
+
"execution_count": 5,
|
| 46 |
+
"id": "16490999",
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [
|
| 49 |
+
{
|
| 50 |
+
"name": "stdout",
|
| 51 |
+
"output_type": "stream",
|
| 52 |
+
"text": [
|
| 53 |
+
"\n",
|
| 54 |
+
" !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
|
| 55 |
+
"Vocab size: 65\n"
|
| 56 |
+
]
|
| 57 |
+
}
|
| 58 |
+
],
|
| 59 |
+
"source": [
|
| 60 |
+
"char = sorted(list(set(text)))\n",
|
| 61 |
+
"vocab_size = len(char)\n",
|
| 62 |
+
"print(\"\".join(char))\n",
|
| 63 |
+
"print(f\"Vocab size: {vocab_size}\")"
|
| 64 |
+
]
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"cell_type": "code",
|
| 68 |
+
"execution_count": 39,
|
| 69 |
+
"id": "d9e6e17a",
|
| 70 |
+
"metadata": {},
|
| 71 |
+
"outputs": [
|
| 72 |
+
{
|
| 73 |
+
"name": "stdout",
|
| 74 |
+
"output_type": "stream",
|
| 75 |
+
"text": [
|
| 76 |
+
"using mps device\n"
|
| 77 |
+
]
|
| 78 |
+
}
|
| 79 |
+
],
|
| 80 |
+
"source": [
|
| 81 |
+
"#use mps as i am using the mac with m4 \n",
|
| 82 |
+
"device = \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
|
| 83 |
+
"print(f\"using {device} device\")"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "code",
|
| 88 |
+
"execution_count": 6,
|
| 89 |
+
"id": "082fd1ba",
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"outputs": [
|
| 92 |
+
{
|
| 93 |
+
"name": "stdout",
|
| 94 |
+
"output_type": "stream",
|
| 95 |
+
"text": [
|
| 96 |
+
"[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]\n",
|
| 97 |
+
"hello world\n"
|
| 98 |
+
]
|
| 99 |
+
}
|
| 100 |
+
],
|
| 101 |
+
"source": [
|
| 102 |
+
"stoi = {ch:i for i,ch in enumerate(char)}\n",
|
| 103 |
+
"itos = {i:ch for i,ch in enumerate(char)}\n",
|
| 104 |
+
"encode = lambda s: [stoi[c] for c in s]\n",
|
| 105 |
+
"decode = lambda l: \"\".join([itos[i] for i in l])\n",
|
| 106 |
+
"print(encode(\"hello world\"))\n",
|
| 107 |
+
"print(decode(encode(\"hello world\"))) # note this is one of the simplest possible tokenizers, it just maps each character to an integer. everyone has their own tokenizer like google use sentencepiece, openai use bpe, etc. we will build our own tokenizer in the next notebook."
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "code",
|
| 112 |
+
"execution_count": 7,
|
| 113 |
+
"id": "7cce9365",
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"outputs": [
|
| 116 |
+
{
|
| 117 |
+
"name": "stdout",
|
| 118 |
+
"output_type": "stream",
|
| 119 |
+
"text": [
|
| 120 |
+
"torch.Size([1115394]) torch.int64\n",
|
| 121 |
+
"tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
|
| 122 |
+
" 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
|
| 123 |
+
" 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
|
| 124 |
+
" 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
|
| 125 |
+
" 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
|
| 126 |
+
" 58, 47, 64, 43, 52, 10, 0, 37, 53, 59])\n"
|
| 127 |
+
]
|
| 128 |
+
}
|
| 129 |
+
],
|
| 130 |
+
"source": [
|
| 131 |
+
"import torch\n",
|
| 132 |
+
"data = torch.tensor(encode(text), dtype=torch.long)\n",
|
| 133 |
+
"print(data.shape, data.dtype)\n",
|
| 134 |
+
"print(data[:100])"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": 8,
|
| 140 |
+
"id": "d59606cc",
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"outputs": [
|
| 143 |
+
{
|
| 144 |
+
"name": "stdout",
|
| 145 |
+
"output_type": "stream",
|
| 146 |
+
"text": [
|
| 147 |
+
"torch.Size([1003854]) torch.Size([111540])\n"
|
| 148 |
+
]
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"source": [
|
| 152 |
+
"n = int(0.9*len(data))\n",
|
| 153 |
+
"train_data = data[:n]\n",
|
| 154 |
+
"val_data = data[n:]\n",
|
| 155 |
+
"print(train_data.shape, val_data.shape)"
|
| 156 |
+
]
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"cell_type": "code",
|
| 160 |
+
"execution_count": 9,
|
| 161 |
+
"id": "e2bd00e4",
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"outputs": [
|
| 164 |
+
{
|
| 165 |
+
"name": "stdout",
|
| 166 |
+
"output_type": "stream",
|
| 167 |
+
"text": [
|
| 168 |
+
"tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])\n"
|
| 169 |
+
]
|
| 170 |
+
}
|
| 171 |
+
],
|
| 172 |
+
"source": [
|
| 173 |
+
"block_size = 8\n",
|
| 174 |
+
"train_data[:block_size+1] # we will use the first 8 characters to predict\n",
|
| 175 |
+
"print(train_data[:block_size+1])"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "code",
|
| 180 |
+
"execution_count": 11,
|
| 181 |
+
"id": "4ce6af03",
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"outputs": [
|
| 184 |
+
{
|
| 185 |
+
"name": "stdout",
|
| 186 |
+
"output_type": "stream",
|
| 187 |
+
"text": [
|
| 188 |
+
"when input is tensor([18]) the target: 47\n",
|
| 189 |
+
"when input is tensor([18, 47]) the target: 56\n",
|
| 190 |
+
"when input is tensor([18, 47, 56]) the target: 57\n",
|
| 191 |
+
"when input is tensor([18, 47, 56, 57]) the target: 58\n",
|
| 192 |
+
"when input is tensor([18, 47, 56, 57, 58]) the target: 1\n",
|
| 193 |
+
"when input is tensor([18, 47, 56, 57, 58, 1]) the target: 15\n",
|
| 194 |
+
"when input is tensor([18, 47, 56, 57, 58, 1, 15]) the target: 47\n",
|
| 195 |
+
"when input is tensor([18, 47, 56, 57, 58, 1, 15, 47]) the target: 58\n"
|
| 196 |
+
]
|
| 197 |
+
}
|
| 198 |
+
],
|
| 199 |
+
"source": [
|
| 200 |
+
"x_train = train_data[:block_size]\n",
|
| 201 |
+
"y_train = train_data[1:block_size+1]\n",
|
| 202 |
+
"for t in range(block_size):\n",
|
| 203 |
+
" context = x_train[:t+1]\n",
|
| 204 |
+
" target = y_train[t]\n",
|
| 205 |
+
" print(f\"when input is {context} the target: {target}\")"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": 38,
|
| 211 |
+
"id": "85e56335",
|
| 212 |
+
"metadata": {},
|
| 213 |
+
"outputs": [
|
| 214 |
+
{
|
| 215 |
+
"name": "stdout",
|
| 216 |
+
"output_type": "stream",
|
| 217 |
+
"text": [
|
| 218 |
+
"inputs:\n",
|
| 219 |
+
"torch.Size([4, 8])\n",
|
| 220 |
+
"tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
|
| 221 |
+
" [44, 53, 56, 1, 58, 46, 39, 58],\n",
|
| 222 |
+
" [52, 58, 1, 58, 46, 39, 58, 1],\n",
|
| 223 |
+
" [25, 17, 27, 10, 0, 21, 1, 54]], device='mps:0')\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"targets:\n",
|
| 226 |
+
"torch.Size([4, 8])\n",
|
| 227 |
+
"tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
|
| 228 |
+
" [53, 56, 1, 58, 46, 39, 58, 1],\n",
|
| 229 |
+
" [58, 1, 58, 46, 39, 58, 1, 46],\n",
|
| 230 |
+
" [17, 27, 10, 0, 21, 1, 54, 39]], device='mps:0')\n"
|
| 231 |
+
]
|
| 232 |
+
}
|
| 233 |
+
],
|
| 234 |
+
"source": [
|
| 235 |
+
"torch.manual_seed(1337)\n",
|
| 236 |
+
"batch_size = 4 # how many independent sequences will we process in parallel?\n",
|
| 237 |
+
"block_size = 8 # what is the maximum context length for predictions?\n",
|
| 238 |
+
"def get_batch(split):\n",
|
| 239 |
+
" data = train_data if split == 'train' else val_data\n",
|
| 240 |
+
" ix = torch.randint(len(data) - block_size, (batch_size,))\n",
|
| 241 |
+
" x = torch.stack([data[i:i+block_size] for i in ix])\n",
|
| 242 |
+
" y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
|
| 243 |
+
" x, y = x.to(device), y.to(device)\n",
|
| 244 |
+
" return x, y\n",
|
| 245 |
+
"xb, yb = get_batch('train')\n",
|
| 246 |
+
"print(\"inputs:\")\n",
|
| 247 |
+
"print(xb.shape)\n",
|
| 248 |
+
"print(xb)\n",
|
| 249 |
+
"print(\"\\ntargets:\")\n",
|
| 250 |
+
"print(yb.shape)\n",
|
| 251 |
+
"print(yb)"
|
| 252 |
+
]
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"cell_type": "code",
|
| 256 |
+
"execution_count": null,
|
| 257 |
+
"id": "18810b27",
|
| 258 |
+
"metadata": {},
|
| 259 |
+
"outputs": [],
|
| 260 |
+
"source": [
|
| 261 |
+
"for b in range(batch_size):\n",
|
| 262 |
+
" for t in range(block_size):\n",
|
| 263 |
+
" context = xb[b, :t+1]\n",
|
| 264 |
+
" target = yb[b, t]\n",
|
| 265 |
+
" print(f\"when input is {context.tolist()} the target: {target.item()}\")"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "code",
|
| 270 |
+
"execution_count": 22,
|
| 271 |
+
"id": "77449b2f",
|
| 272 |
+
"metadata": {},
|
| 273 |
+
"outputs": [
|
| 274 |
+
{
|
| 275 |
+
"name": "stdout",
|
| 276 |
+
"output_type": "stream",
|
| 277 |
+
"text": [
|
| 278 |
+
"tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
|
| 279 |
+
" [44, 53, 56, 1, 58, 46, 39, 58],\n",
|
| 280 |
+
" [52, 58, 1, 58, 46, 39, 58, 1],\n",
|
| 281 |
+
" [25, 17, 27, 10, 0, 21, 1, 54]])\n",
|
| 282 |
+
"\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
|
| 285 |
+
" [53, 56, 1, 58, 46, 39, 58, 1],\n",
|
| 286 |
+
" [58, 1, 58, 46, 39, 58, 1, 46],\n",
|
| 287 |
+
" [17, 27, 10, 0, 21, 1, 54, 39]])\n"
|
| 288 |
+
]
|
| 289 |
+
}
|
| 290 |
+
],
|
| 291 |
+
"source": [
|
| 292 |
+
"print(xb)\n",
|
| 293 |
+
"print(\"\\n\")\n",
|
| 294 |
+
"print(yb)"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"cell_type": "code",
|
| 299 |
+
"execution_count": null,
|
| 300 |
+
"id": "66a1c195",
|
| 301 |
+
"metadata": {},
|
| 302 |
+
"outputs": [
|
| 303 |
+
{
|
| 304 |
+
"name": "stdout",
|
| 305 |
+
"output_type": "stream",
|
| 306 |
+
"text": [
|
| 307 |
+
"logits shape: torch.Size([32, 65])\n",
|
| 308 |
+
"loss: 4.878634929656982\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"SKIcLT;AcE\n"
|
| 311 |
+
]
|
| 312 |
+
}
|
| 313 |
+
],
|
| 314 |
+
"source": [
|
| 315 |
+
"#Bigram language model\n",
|
| 316 |
+
"import torch\n",
|
| 317 |
+
"import torch.nn as nn\n",
|
| 318 |
+
"import torch.nn.functional as F\n",
|
| 319 |
+
"torch.manual_seed(1337)\n",
|
| 320 |
+
"class BigramLanguageModel(torch.nn.Module):\n",
|
| 321 |
+
" def __init__(self, vocab_size):\n",
|
| 322 |
+
" super().__init__()\n",
|
| 323 |
+
" self.token_embedding_table = torch.nn.Embedding(vocab_size, vocab_size)\n",
|
| 324 |
+
" def forward(self, idx, targets=None):\n",
|
| 325 |
+
" # idx and targets are both (B,T) tensor of integers\n",
|
| 326 |
+
" logits = self.token_embedding_table(idx) # (B,T,C)\n",
|
| 327 |
+
" if targets is None:\n",
|
| 328 |
+
" loss = None\n",
|
| 329 |
+
" else:\n",
|
| 330 |
+
" B,T,C = logits.shape\n",
|
| 331 |
+
" logits = logits.view(B*T, C)\n",
|
| 332 |
+
" targets = targets.view(B*T)\n",
|
| 333 |
+
" loss = F.cross_entropy(logits, targets)\n",
|
| 334 |
+
" return logits, loss\n",
|
| 335 |
+
" \n",
|
| 336 |
+
" def generate(self, idx, max_new_tokens):\n",
|
| 337 |
+
" # idx is (B,T) array of indices in the current context\n",
|
| 338 |
+
" for _ in range(max_new_tokens):\n",
|
| 339 |
+
" logits, loss = self(idx)\n",
|
| 340 |
+
" logits = logits[:, -1, :] # becomes (B,C) , as we only want to provide the last character as the input to predict the next character\n",
|
| 341 |
+
" probs = F.softmax(logits, dim=-1) # (B,C)\n",
|
| 342 |
+
" idx_next = torch.multinomial(probs, num_samples=1) # (B,1)\n",
|
| 343 |
+
" idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)\n",
|
| 344 |
+
" return idx\n",
|
| 345 |
+
" \n",
|
| 346 |
+
"model = BigramLanguageModel(vocab_size)\n",
|
| 347 |
+
"model₹.to(device)\n",
|
| 348 |
+
"logits, loss = model(xb, yb)\n",
|
| 349 |
+
"print(\"logits shape:\", logits.shape)\n",
|
| 350 |
+
"print(\"loss:\", loss.item())\n",
|
| 351 |
+
"print(decode(model.generate(idx=torch.zeros((1,1), dtype=torch.long), max_new_tokens=10)[0].tolist()))"
|
| 352 |
+
]
|
| 353 |
+
},
|
| 354 |
+
{
|
| 355 |
+
"cell_type": "code",
|
| 356 |
+
"execution_count": 35,
|
| 357 |
+
"id": "ecd49fc4",
|
| 358 |
+
"metadata": {},
|
| 359 |
+
"outputs": [
|
| 360 |
+
{
|
| 361 |
+
"name": "stdout",
|
| 362 |
+
"output_type": "stream",
|
| 363 |
+
"text": [
|
| 364 |
+
"step 9999: loss 2.4313366413116455\n"
|
| 365 |
+
]
|
| 366 |
+
}
|
| 367 |
+
],
|
| 368 |
+
"source": [
|
| 369 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n",
|
| 370 |
+
"batch_size = 32\n",
|
| 371 |
+
"for steps in range(10000):\n",
|
| 372 |
+
" xb, yb = get_batch('train')\n",
|
| 373 |
+
" logits, loss = model(xb, yb)\n",
|
| 374 |
+
" optimizer.zero_grad(set_to_none=True)\n",
|
| 375 |
+
" loss.backward()\n",
|
| 376 |
+
" optimizer.step()\n",
|
| 377 |
+
"print(f\"step {steps}: loss {loss.item()}\")\n",
|
| 378 |
+
"\n"
|
| 379 |
+
]
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"cell_type": "code",
|
| 383 |
+
"execution_count": 36,
|
| 384 |
+
"id": "8ce29e5a",
|
| 385 |
+
"metadata": {},
|
| 386 |
+
"outputs": [
|
| 387 |
+
{
|
| 388 |
+
"name": "stdout",
|
| 389 |
+
"output_type": "stream",
|
| 390 |
+
"text": [
|
| 391 |
+
"\n",
|
| 392 |
+
"Warstyo a \n"
|
| 393 |
+
]
|
| 394 |
+
}
|
| 395 |
+
],
|
| 396 |
+
"source": [
|
| 397 |
+
"print(decode(model.generate(idx=torch.zeros((1,1), dtype=torch.long), max_new_tokens=10)[0].tolist()))"
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"cell_type": "code",
|
| 402 |
+
"execution_count": null,
|
| 403 |
+
"id": "b87bd156",
|
| 404 |
+
"metadata": {},
|
| 405 |
+
"outputs": [
|
| 406 |
+
{
|
| 407 |
+
"name": "stdout",
|
| 408 |
+
"output_type": "stream",
|
| 409 |
+
"text": [
|
| 410 |
+
"using mps device\n"
|
| 411 |
+
]
|
| 412 |
+
}
|
| 413 |
+
],
|
| 414 |
+
"source": []
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"cell_type": "code",
|
| 418 |
+
"execution_count": null,
|
| 419 |
+
"id": "c9f2052b",
|
| 420 |
+
"metadata": {},
|
| 421 |
+
"outputs": [],
|
| 422 |
+
"source": []
|
| 423 |
+
}
|
| 424 |
+
],
|
| 425 |
+
"metadata": {
|
| 426 |
+
"kernelspec": {
|
| 427 |
+
"display_name": ".venv",
|
| 428 |
+
"language": "python",
|
| 429 |
+
"name": "python3"
|
| 430 |
+
},
|
| 431 |
+
"language_info": {
|
| 432 |
+
"codemirror_mode": {
|
| 433 |
+
"name": "ipython",
|
| 434 |
+
"version": 3
|
| 435 |
+
},
|
| 436 |
+
"file_extension": ".py",
|
| 437 |
+
"mimetype": "text/x-python",
|
| 438 |
+
"name": "python",
|
| 439 |
+
"nbconvert_exporter": "python",
|
| 440 |
+
"pygments_lexer": "ipython3",
|
| 441 |
+
"version": "3.14.2"
|
| 442 |
+
}
|
| 443 |
+
},
|
| 444 |
+
"nbformat": 4,
|
| 445 |
+
"nbformat_minor": 5
|
| 446 |
+
}
|