# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)

<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training.png" alt="">

## Overview

This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases:

- **Single Image + Text**
- **Multi-Image + Text**

This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly.

We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets.

## Understanding the Datasets

To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task.

### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text)

This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input.

The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**.
<iframe
  src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text)

The **FanqingM/MMIU-Benchmark** dataset consists of:

- **Context:** Included in the system prompt.
- **Question:** Provided as part of the user's input.
- **Series of Images:** Multiple images related to the question.
- **Answer:** The model's expected response.

This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs.
<iframe
  src="https://huggingface.co/datasets/FanqingM/MMIU-Benchmark/embed/viewer/default/test"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
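If you prefer to inspect the raw columns programmatically rather than in the embedded viewer, here is a minimal sketch. It assumes the `test` split and the `context`, `question`, `input_image_path`, and `output` columns that the formatting code later in this guide relies on:

```python
from datasets import load_dataset

# Load only the "test" split, which is the split used throughout this guide
mmiu = load_dataset("FanqingM/MMIU-Benchmark", split="test")

# The formatting function later in this guide relies on "context", "question",
# "input_image_path", and "output"
print(mmiu.column_names)
print(mmiu[0]["question"])
```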
## Developing a Fine-Tuning Script for Multimodal SFT

In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases.

### Setting Up the Environment

Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment:

```bash
# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation
pip install -U -q trl bitsandbytes peft hf_xet tensorboard
```
Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required.

If you haven’t requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it.

To log in, you’ll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account.

```bash
huggingface-cli login
```
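If you are working in a notebook, you can alternatively log in from Python using the `login` helper from `huggingface_hub`, which is equivalent to the CLI command above:

```python
from huggingface_hub import login

# Prompts for the access token created at https://huggingface.co/settings/tokens
login()
```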
### Loading the Data

As mentioned earlier, we will cover two possible use cases. While the specific procedure may vary based on the dataset, the core principles remain consistent.

This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario.

#### Single Image + Text

<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_data_single.png" alt="">

In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`.

```python
from datasets import load_dataset

dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"

# Load Dataset
dataset = load_dataset(dataset_name)
```
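As an optional sanity check, you can look at a single example. The dataset stores the conversation in a `messages` column and the associated image(s) in an `images` column, which is what the collator defined later relies on:

```python
# Inspect one training example (optional)
example = dataset["train"][0]

print(example.keys())          # expected: dict_keys(['messages', 'images'])
print(example["messages"][0])  # first conversation turn
print(example["images"][0])    # a PIL image
```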
#### Multi-Image + Text (or Interleaving)

<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_data_multi.png" alt="">

Gemma 3 also supports **Multi-Image + Text** scenarios, where:

- The model receives a **list of images** alongside a user message.
- The model processes **interleaved images and text** within a conversation.

For this dataset, some preprocessing is required before training.

```python
from datasets import load_dataset

dataset_name = "FanqingM/MMIU-Benchmark"

# Load Dataset
dataset = load_dataset(dataset_name)
```

After loading the dataset, we need to preprocess and format it into a conversational structure. Here’s an example of how the data might look:

```python
{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
```
Here, `images_list` is a list of images:

```python
images_list = [
    {"type": "image", "image": <class 'PIL.Image.Image'>},
    {"type": "image", "image": <class 'PIL.Image.Image'>},
    {"type": "image", "image": <class 'PIL.Image.Image'>},
    {"type": "image", "image": <class 'PIL.Image.Image'>},
    {"type": "image", "image": <class 'PIL.Image.Image'>},
]
```
This structure can be translated into code like this:

```python
import io
import os
import zipfile

from datasets import DatasetDict
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image

dataset_train_split = "test"

def format_data(samples: dict[str, any]) -> dict[str, list]:
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        # Load every image referenced by this sample from the extracted archives
        images = []
        for img_path in samples["input_image_path"][cont]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue

        # Build the conversation: context -> system, images + question -> user, answer -> assistant
        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
                {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
            ]
        )
    return formatted_samples

# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
    # Download and extract the image archives so the relative paths in "input_image_path" resolve locally
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]

    for zip_filename in zip_files:
        zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
        extract_folder = zip_filename.replace(".zip", "")
        os.makedirs(extract_folder, exist_ok=True)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_folder)

    dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
    return dataset

dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
```
With this, your **Multi-Image + Text** dataset is now prepared for training.
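Optionally, you can verify that the formatting produced the expected three-turn conversations (a small check using the `test` split loaded above):

```python
# Each formatted example should contain a system, a user, and an assistant turn
first_conversation = dataset[dataset_train_split][0]["messages"]
print([turn["role"] for turn in first_conversation])  # ['system', 'user', 'assistant']
```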
### Preparing for Training

We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models.

To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model.

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

model_id = "google/gemma-3-4b-it"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

# Load the model and processor
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
    quantization_config=bnb_config,
)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"
```
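As an optional check before moving on, you can confirm how much memory the 4-bit model actually occupies (a minimal sketch using the standard `get_memory_footprint` helper from transformers):

```python
# Optional: check the memory the quantized model occupies
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```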
Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs).

```python
from peft import LoraConfig

# Configure QLoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)
```
With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs.

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft",  # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
    num_train_epochs=1,  # Number of epochs to train the model.
    per_device_train_batch_size=8,  # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
    gradient_accumulation_steps=4,  # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
    gradient_checkpointing=True,  # Enable gradient checkpointing to reduce memory usage during training.
    optim="adamw_torch_fused",  # Use the fused AdamW optimizer for better performance.
    logging_steps=10,  # Log training progress every 10 steps.
    save_strategy="epoch",  # Save checkpoints at the end of each epoch.
    learning_rate=2e-05,  # Learning rate for training.
    bf16=True,  # Enable bfloat16 precision for training to save memory and speed up computation.
    push_to_hub=True,  # Automatically push the fine-tuned model to the Hugging Face Hub after training.
    report_to="tensorboard",  # Report metrics to TensorBoard.
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Use non-reentrant gradient checkpointing to avoid issues.
    dataset_kwargs={"skip_prepare_dataset": True},  # Skip TRL's dataset preparation so preprocessing is handled manually.
    remove_unused_columns=False,  # Keep all dataset columns so the collator can access them.
)
```
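Because gradient accumulation multiplies the per-device batch size, it can be useful to print the effective batch size you are actually training with. A rough sketch, assuming one process per available GPU:

```python
import torch

# Effective batch size = per-device batch size x gradient accumulation steps x number of devices
num_devices = max(torch.cuda.device_count(), 1)
effective_batch_size = (
    training_args.per_device_train_batch_size
    * training_args.gradient_accumulation_steps
    * num_devices
)
print(f"Effective batch size: {effective_batch_size}")
```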
The `collate_fn` is responsible for processing and preparing individual examples to form a batch.

Each example in the batch undergoes the following steps:

1. The **chat template** is applied to the text.
2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors.
3. The **labels** for training are set as the `input_ids` of the example.
4. Certain **special tokens** are **masked (ignored)** during loss computation:
   - `pad_token_id`
   - `<image_token_id>`
   - `<image_soft_token>` (corresponding to ID `262144`)

This process is similar across different dataset types, with a minor variation in how images are handled:

- **Single Image + Text** → A **list of images** is directly processed.
- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images.
```python
import io

from PIL import Image

# For multi-image cases
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    # After `datasets.map`, each image is stored as a dict holding its raw bytes,
                    # so we rebuild the PIL image here
                    image = Image.open(io.BytesIO(image["bytes"]))
                    image_inputs.append(image.convert("RGB"))
    return image_inputs

def collate_fn(examples):
    # Apply the chat template to the conversation of each example
    texts = [
        processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip()
        for example in examples
    ]

    # Single-image: a list of images per example; multi-image: extract the images from the messages
    if "images" in examples[0]:  # single-image
        images = [[img.convert("RGB") for img in example["images"]] for example in examples]
    else:  # multi-image
        images = [process_vision_info(example["messages"]) for example in examples]

    # Tokenize the texts and process the images into tensors
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids; padding and image tokens are masked in the loss computation
    labels = batch["input_ids"].clone()

    # Token id of the begin-of-image token
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
    ]

    # Mask tokens that should not contribute to the loss
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100  # <image_soft_token>

    batch["labels"] = labels
    return batch  # Return the prepared batch
```
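Before launching training, you can sanity-check the collator on a couple of examples. This is only a sketch: the key names (`input_ids`, `pixel_values`) follow the Gemma 3 processor output and may differ for other models, and for the multi-image case you would index `dataset["test"]` instead:

```python
# Build a small batch and inspect its shapes
sample_batch = collate_fn([dataset["train"][i] for i in range(2)])

print(sample_batch["input_ids"].shape)              # (batch_size, sequence_length)
print(sample_batch["pixel_values"].shape)           # image tensor produced by the processor
print(int((sample_batch["labels"] == -100).sum()))  # number of tokens ignored by the loss
```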
### Training the Model

With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process.

```python
# Training
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],  # multi-image -> train_dataset=dataset["test"]
    processing_class=processor,
    peft_config=peft_config,
)

trainer.train()

# Save the final model
trainer.save_model()
```

We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration.
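Once training has finished and the model has been pushed to the Hub, you can load it back and try it out. The sketch below is a minimal example: the repository id and image URL are placeholders to replace with your own, and it assumes a recent transformers version whose processors can resolve `url` entries in chat templates:

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

# Placeholder repository id: replace with the repo created by `push_to_hub`
trained_model_id = "your-username/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft"

model = AutoModelForImageTextToText.from_pretrained(trained_model_id, torch_dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained(trained_model_id)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/some-image.jpg"},  # placeholder image URL
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=64)

# Decode only the newly generated tokens
print(processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```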
### Results

During and after training, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. The fine-tuned example models are available on the Hub:

* [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image + Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft)
* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Image + Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark)

## Limitations

Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results.

## References

For further reading and complementary resources, check out the following:

- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl)