wip: add training pipeline 1
Browse files- article_base_train_test.py +80 -0
- article_base_tutorial.ipynb +289 -0
article_base_train_test.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import notebook_login
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
|
| 4 |
+
import torch
|
| 5 |
+
from peft import get_peft_model, LoraConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
ds = load_dataset('HuggingFaceM4/VQAv2', split="train", trust_remote_code=True)
|
| 10 |
+
cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
|
| 11 |
+
ds = ds.remove_columns(cols_remove)
|
| 12 |
+
ds = ds.train_test_split(test_size=0.1)
|
| 13 |
+
train_ds = ds["train"]
|
| 14 |
+
val_ds = ds["test"]
|
| 15 |
+
|
| 16 |
+
model_id = "google/paligemma-3b-pt-224"
|
| 17 |
+
processor = PaliGemmaProcessor.from_pretrained(model_id)
|
| 18 |
+
image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
|
| 19 |
+
device = "cuda"
|
| 20 |
+
|
| 21 |
+
bnb_config = BitsAndBytesConfig(
|
| 22 |
+
load_in_4bit=True,
|
| 23 |
+
bnb_4bit_quant_type="nf4",
|
| 24 |
+
bnb_4bit_compute_type=torch.bfloat16
|
| 25 |
+
)
|
| 26 |
+
lora_config = LoraConfig(
|
| 27 |
+
r=8,
|
| 28 |
+
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
|
| 29 |
+
task_type="CAUSAL_LM",
|
| 30 |
+
)
|
| 31 |
+
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
|
| 32 |
+
model = get_peft_model(model, lora_config)
|
| 33 |
+
model.print_trainable_parameters()
|
| 34 |
+
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344
|
| 35 |
+
|
| 36 |
+
args=TrainingArguments(
|
| 37 |
+
num_train_epochs=2,
|
| 38 |
+
remove_unused_columns=False,
|
| 39 |
+
per_device_train_batch_size=16,
|
| 40 |
+
gradient_accumulation_steps=4,
|
| 41 |
+
warmup_steps=2,
|
| 42 |
+
learning_rate=2e-5,
|
| 43 |
+
weight_decay=1e-6,
|
| 44 |
+
adam_beta2=0.999,
|
| 45 |
+
logging_steps=100,
|
| 46 |
+
# optim="adamw_hf",
|
| 47 |
+
optim="paged_adamw_8bit", # for QLoRA
|
| 48 |
+
save_strategy="steps",
|
| 49 |
+
save_steps=1000,
|
| 50 |
+
push_to_hub=True,
|
| 51 |
+
save_total_limit=1,
|
| 52 |
+
bf16=True,
|
| 53 |
+
report_to=["tensorboard"],
|
| 54 |
+
dataloader_pin_memory=False
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def collate_fn(examples):
|
| 58 |
+
texts = ["answer " + example["question"] for example in examples]
|
| 59 |
+
labels= [example['multiple_choice_answer'] for example in examples] # μ°λ¦¬λ label μ΄ νμ μμλ―?
|
| 60 |
+
images = [example["image"].convert("RGB") for example in examples]
|
| 61 |
+
tokens = processor(text=texts, images=images, suffix=labels,
|
| 62 |
+
return_tensors="pt", padding="longest")
|
| 63 |
+
|
| 64 |
+
tokens = tokens.to(torch.bfloat16).to(device)
|
| 65 |
+
return tokens
|
| 66 |
+
|
| 67 |
+
trainer = Trainer(
|
| 68 |
+
model=model,
|
| 69 |
+
train_dataset=train_ds,
|
| 70 |
+
eval_dataset=val_ds,
|
| 71 |
+
data_collator=collate_fn,
|
| 72 |
+
args=args
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
trainer.train()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
notebook_login()
|
| 80 |
+
main()
|
article_base_tutorial.ipynb
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"data": {
|
| 10 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 11 |
+
"model_id": "4d8a398ca84a42d7b745d8e32e6ad3dd",
|
| 12 |
+
"version_major": 2,
|
| 13 |
+
"version_minor": 0
|
| 14 |
+
},
|
| 15 |
+
"text/plain": [
|
| 16 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svβ¦"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"output_type": "display_data"
|
| 21 |
+
}
|
| 22 |
+
],
|
| 23 |
+
"source": [
|
| 24 |
+
"from huggingface_hub import notebook_login\n",
|
| 25 |
+
"notebook_login()"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "markdown",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"source": [
|
| 32 |
+
"# Load Dataset"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "code",
|
| 37 |
+
"execution_count": 2,
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"outputs": [
|
| 40 |
+
{
|
| 41 |
+
"name": "stderr",
|
| 42 |
+
"output_type": "stream",
|
| 43 |
+
"text": [
|
| 44 |
+
"Repo card metadata block was not found. Setting CardData to empty.\n"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"data": {
|
| 49 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 50 |
+
"model_id": "83c256fd38a143b6abded3fbf09d8bd8",
|
| 51 |
+
"version_major": 2,
|
| 52 |
+
"version_minor": 0
|
| 53 |
+
},
|
| 54 |
+
"text/plain": [
|
| 55 |
+
"Downloading data: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
"metadata": {},
|
| 59 |
+
"output_type": "display_data"
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"ename": "FSTimeoutError",
|
| 63 |
+
"evalue": "",
|
| 64 |
+
"output_type": "error",
|
| 65 |
+
"traceback": [
|
| 66 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 67 |
+
"\u001b[0;31mTimeoutError\u001b[0m Traceback (most recent call last)",
|
| 68 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:56\u001b[0m, in \u001b[0;36m_runner\u001b[0;34m(event, coro, result, timeout)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 56\u001b[0m result[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m coro\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m ex:\n",
|
| 69 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/implementations/http.py:262\u001b[0m, in \u001b[0;36mHTTPFileSystem._get_file\u001b[0;34m(self, rpath, lpath, chunk_size, callback, **kwargs)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m chunk:\n\u001b[0;32m--> 262\u001b[0m chunk \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m r\u001b[38;5;241m.\u001b[39mcontent\u001b[38;5;241m.\u001b[39mread(chunk_size)\n\u001b[1;32m 263\u001b[0m outfile\u001b[38;5;241m.\u001b[39mwrite(chunk)\n",
|
| 70 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/streams.py:396\u001b[0m, in \u001b[0;36mStreamReader.read\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_eof:\n\u001b[0;32m--> 396\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mread\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_nowait(n)\n",
|
| 71 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/streams.py:314\u001b[0m, in \u001b[0;36mStreamReader._wait\u001b[0;34m(self, func_name)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 314\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_timer\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mawait\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mwaiter\u001b[49m\n",
|
| 72 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/helpers.py:719\u001b[0m, in \u001b[0;36mTimerContext.__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 718\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exc_type \u001b[38;5;129;01mis\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mCancelledError \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cancelled:\n\u001b[0;32m--> 719\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mTimeoutError \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 720\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 73 |
+
"\u001b[0;31mTimeoutError\u001b[0m: ",
|
| 74 |
+
"\nThe above exception was the direct cause of the following exception:\n",
|
| 75 |
+
"\u001b[0;31mFSTimeoutError\u001b[0m Traceback (most recent call last)",
|
| 76 |
+
"Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_dataset \n\u001b[0;32m----> 2\u001b[0m ds \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mHuggingFaceM4/VQAv2\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m \n\u001b[1;32m 3\u001b[0m cols_remove \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquestion_type\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manswers\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manswer_type\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage_id\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquestion_id\u001b[39m\u001b[38;5;124m\"\u001b[39m] \n\u001b[1;32m 4\u001b[0m ds \u001b[38;5;241m=\u001b[39m ds\u001b[38;5;241m.\u001b[39mremove_columns(cols_remove)\n",
|
| 77 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/load.py:2096\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, keep_in_memory, save_infos, revision, token, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2093\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\u001b[38;5;241m.\u001b[39mas_streaming_dataset(split\u001b[38;5;241m=\u001b[39msplit)\n\u001b[1;32m 2095\u001b[0m \u001b[38;5;66;03m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 2096\u001b[0m \u001b[43mbuilder_instance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2097\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2098\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2099\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2100\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2101\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2102\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2104\u001b[0m \u001b[38;5;66;03m# Build dataset for splits\u001b[39;00m\n\u001b[1;32m 2105\u001b[0m keep_in_memory \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2106\u001b[0m keep_in_memory \u001b[38;5;28;01mif\u001b[39;00m keep_in_memory \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m is_small_dataset(builder_instance\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size)\n\u001b[1;32m 2107\u001b[0m )\n",
|
| 78 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:924\u001b[0m, in \u001b[0;36mDatasetBuilder.download_and_prepare\u001b[0;34m(self, output_dir, download_config, download_mode, verification_mode, dl_manager, base_path, file_format, max_shard_size, num_proc, storage_options, **download_and_prepare_kwargs)\u001b[0m\n\u001b[1;32m 922\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_proc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 923\u001b[0m prepare_split_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_proc\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m num_proc\n\u001b[0;32m--> 924\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 925\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 926\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 927\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_split_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 928\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdownload_and_prepare_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 929\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 930\u001b[0m \u001b[38;5;66;03m# Sync info\u001b[39;00m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(split\u001b[38;5;241m.\u001b[39mnum_bytes \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39msplits\u001b[38;5;241m.\u001b[39mvalues())\n",
|
| 79 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:1647\u001b[0m, in \u001b[0;36mGeneratorBasedBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_splits_kwargs)\u001b[0m\n\u001b[1;32m 1646\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_download_and_prepare\u001b[39m(\u001b[38;5;28mself\u001b[39m, dl_manager, verification_mode, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mprepare_splits_kwargs):\n\u001b[0;32m-> 1647\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1648\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1649\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1650\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_duplicate_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mVerificationMode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mBASIC_CHECKS\u001b[49m\n\u001b[1;32m 1651\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mVerificationMode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mALL_CHECKS\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1652\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_splits_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1653\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
| 80 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:977\u001b[0m, in \u001b[0;36mDatasetBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_split_kwargs)\u001b[0m\n\u001b[1;32m 975\u001b[0m split_dict \u001b[38;5;241m=\u001b[39m SplitDict(dataset_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset_name)\n\u001b[1;32m 976\u001b[0m split_generators_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_split_generators_kwargs(prepare_split_kwargs)\n\u001b[0;32m--> 977\u001b[0m split_generators \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_split_generators\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msplit_generators_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# Checksums verification\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m verification_mode \u001b[38;5;241m==\u001b[39m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS \u001b[38;5;129;01mand\u001b[39;00m dl_manager\u001b[38;5;241m.\u001b[39mrecord_checksums:\n",
|
| 81 |
+
"File \u001b[0;32m~/.cache/huggingface/modules/datasets_modules/datasets/HuggingFaceM4--VQAv2/e4d008385143be7a6bd81e99483e671d5096942bcb987542217121a5ac2cb420/VQAv2.py:118\u001b[0m, in \u001b[0;36mVQAv2Dataset._split_generators\u001b[0;34m(self, dl_manager)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_split_generators\u001b[39m(\u001b[38;5;28mself\u001b[39m, dl_manager):\n\u001b[1;32m 117\u001b[0m \u001b[38;5;66;03m# urls = _URLS[self.config.name] # TODO later\u001b[39;00m\n\u001b[0;32m--> 118\u001b[0m data_dir \u001b[38;5;241m=\u001b[39m \u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_extract\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_URLS\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m gen_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 120\u001b[0m split_name: {\n\u001b[1;32m 121\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdir_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_path\u001b[39m\u001b[38;5;124m\"\u001b[39m: Path(data_dir[dir_name][split_name])\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m split_name \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mval\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest-dev\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 128\u001b[0m }\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 130\u001b[0m datasets\u001b[38;5;241m.\u001b[39mSplitGenerator(\n\u001b[1;32m 131\u001b[0m name\u001b[38;5;241m=\u001b[39mdatasets\u001b[38;5;241m.\u001b[39mSplit\u001b[38;5;241m.\u001b[39mTRAIN,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 145\u001b[0m ),\n\u001b[1;32m 146\u001b[0m ]\n",
|
| 82 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:322\u001b[0m, in \u001b[0;36mDownloadManager.download_and_extract\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdownload_and_extract\u001b[39m(\u001b[38;5;28mself\u001b[39m, url_or_urls):\n\u001b[1;32m 307\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Download and extract given `url_or_urls`.\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \n\u001b[1;32m 309\u001b[0m \u001b[38;5;124;03m Is roughly equivalent to:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124;03m extracted_path(s): `str`, extracted paths of given URL(s).\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mextract(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_urls\u001b[49m\u001b[43m)\u001b[49m)\n",
|
| 83 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:159\u001b[0m, in \u001b[0;36mDownloadManager.download\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 157\u001b[0m start_time \u001b[38;5;241m=\u001b[39m datetime\u001b[38;5;241m.\u001b[39mnow()\n\u001b[1;32m 158\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m stack_multiprocessing_download_progress_bars():\n\u001b[0;32m--> 159\u001b[0m downloaded_path_or_paths \u001b[38;5;241m=\u001b[39m \u001b[43mmap_nested\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 160\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[43m \u001b[49m\u001b[43murl_or_urls\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[43m \u001b[49m\u001b[43mmap_tuple\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mDownloading data files\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 165\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 167\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m duration \u001b[38;5;241m=\u001b[39m datetime\u001b[38;5;241m.\u001b[39mnow() \u001b[38;5;241m-\u001b[39m start_time\n\u001b[1;32m 169\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading took \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mduration\u001b[38;5;241m.\u001b[39mtotal_seconds()\u001b[38;5;250m \u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m60\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m min\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 84 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:512\u001b[0m, in \u001b[0;36mmap_nested\u001b[0;34m(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, parallel_min_length, batched, batch_size, types, disable_tqdm, desc)\u001b[0m\n\u001b[1;32m 509\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;28mlen\u001b[39m(iterable) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m num_proc \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(iterable) \u001b[38;5;241m%\u001b[39m num_proc \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m), \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 510\u001b[0m iterable \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(iter_batched(iterable, batch_size))\n\u001b[1;32m 511\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m--> 512\u001b[0m \u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 513\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m hf_tqdm(iterable, disable\u001b[38;5;241m=\u001b[39mdisable_tqdm, desc\u001b[38;5;241m=\u001b[39mdesc)\n\u001b[1;32m 514\u001b[0m ]\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batched:\n\u001b[1;32m 516\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [mapped_item \u001b[38;5;28;01mfor\u001b[39;00m mapped_batch \u001b[38;5;129;01min\u001b[39;00m mapped \u001b[38;5;28;01mfor\u001b[39;00m mapped_item \u001b[38;5;129;01min\u001b[39;00m mapped_batch]\n",
|
| 85 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:399\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m 396\u001b[0m k: _single_map_nested((function, v, batched, batch_size, types, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m pbar\n\u001b[1;32m 397\u001b[0m }\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 399\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [\u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m pbar]\n\u001b[1;32m 400\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m mapped\n",
|
| 86 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:396\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m hf_tqdm(pbar_iterable, disable\u001b[38;5;241m=\u001b[39mdisable_tqdm, position\u001b[38;5;241m=\u001b[39mrank, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobj\u001b[39m\u001b[38;5;124m\"\u001b[39m, desc\u001b[38;5;241m=\u001b[39mpbar_desc) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[1;32m 394\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[0;32m--> 396\u001b[0m k: \u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m pbar\n\u001b[1;32m 397\u001b[0m }\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 399\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [_single_map_nested((function, v, batched, batch_size, types, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m pbar]\n",
|
| 87 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:371\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, types):\n\u001b[1;32m 370\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batched:\n\u001b[0;32m--> 371\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdata_struct\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 372\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 373\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function(data_struct)\n",
|
| 88 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:216\u001b[0m, in \u001b[0;36mDownloadManager._download_batched\u001b[0;34m(self, url_or_filenames, download_config)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m thread_map(\n\u001b[1;32m 203\u001b[0m download_func,\n\u001b[1;32m 204\u001b[0m url_or_filenames,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 212\u001b[0m tqdm_class\u001b[38;5;241m=\u001b[39mtqdm,\n\u001b[1;32m 213\u001b[0m )\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_single\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m url_or_filename \u001b[38;5;129;01min\u001b[39;00m url_or_filenames\n\u001b[1;32m 218\u001b[0m ]\n",
|
| 89 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:225\u001b[0m, in \u001b[0;36mDownloadManager._download_single\u001b[0;34m(self, url_or_filename, download_config)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_relative_path(url_or_filename):\n\u001b[1;32m 223\u001b[0m \u001b[38;5;66;03m# append the relative path to the base_path\u001b[39;00m\n\u001b[1;32m 224\u001b[0m url_or_filename \u001b[38;5;241m=\u001b[39m url_or_path_join(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_base_path, url_or_filename)\n\u001b[0;32m--> 225\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mcached_path\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 226\u001b[0m out \u001b[38;5;241m=\u001b[39m tracked_str(out)\n\u001b[1;32m 227\u001b[0m out\u001b[38;5;241m.\u001b[39mset_origin(url_or_filename)\n",
|
| 90 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:205\u001b[0m, in \u001b[0;36mcached_path\u001b[0;34m(url_or_filename, download_config, **download_kwargs)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;66;03m# Download external files\u001b[39;00m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 205\u001b[0m output_path \u001b[38;5;241m=\u001b[39m \u001b[43mget_from_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[43m \u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 208\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 209\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 210\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_etag\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_etag\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 211\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 212\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_desc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_desc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisable_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdisable_tqdm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 215\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexists(url_or_filename):\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# File, and it exists.\u001b[39;00m\n\u001b[1;32m 218\u001b[0m output_path \u001b[38;5;241m=\u001b[39m url_or_filename\n",
|
| 91 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:415\u001b[0m, in \u001b[0;36mget_from_cache\u001b[0;34m(url, cache_dir, force_download, user_agent, use_etag, token, storage_options, download_desc, disable_tqdm)\u001b[0m\n\u001b[1;32m 413\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not found in cache or force_download set to True, downloading to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtemp_file\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 414\u001b[0m \u001b[38;5;66;03m# GET file object\u001b[39;00m\n\u001b[0;32m--> 415\u001b[0m \u001b[43mfsspec_get\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_desc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisable_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_tqdm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 417\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstoring \u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in cache at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcache_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 418\u001b[0m shutil\u001b[38;5;241m.\u001b[39mmove(temp_file\u001b[38;5;241m.\u001b[39mname, cache_path)\n",
|
| 92 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:334\u001b[0m, in \u001b[0;36mfsspec_get\u001b[0;34m(url, temp_file, storage_options, desc, disable_tqdm)\u001b[0m\n\u001b[1;32m 321\u001b[0m fs, path \u001b[38;5;241m=\u001b[39m url_to_fs(url, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m(storage_options \u001b[38;5;129;01mor\u001b[39;00m {}))\n\u001b[1;32m 322\u001b[0m callback \u001b[38;5;241m=\u001b[39m TqdmCallback(\n\u001b[1;32m 323\u001b[0m tqdm_kwargs\u001b[38;5;241m=\u001b[39m{\n\u001b[1;32m 324\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdesc\u001b[39m\u001b[38;5;124m\"\u001b[39m: desc \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 332\u001b[0m }\n\u001b[1;32m 333\u001b[0m )\n\u001b[0;32m--> 334\u001b[0m \u001b[43mfs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 93 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:118\u001b[0m, in \u001b[0;36msync_wrapper.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m obj \u001b[38;5;129;01mor\u001b[39;00m args[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msync\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 94 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:101\u001b[0m, in \u001b[0;36msync\u001b[0;34m(loop, func, timeout, *args, **kwargs)\u001b[0m\n\u001b[1;32m 98\u001b[0m return_result \u001b[38;5;241m=\u001b[39m result[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(return_result, asyncio\u001b[38;5;241m.\u001b[39mTimeoutError):\n\u001b[1;32m 100\u001b[0m \u001b[38;5;66;03m# suppress asyncio.TimeoutError, raise FSTimeoutError\u001b[39;00m\n\u001b[0;32m--> 101\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m FSTimeoutError \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mreturn_result\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(return_result, \u001b[38;5;167;01mBaseException\u001b[39;00m):\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m return_result\n",
|
| 95 |
+
"\u001b[0;31mFSTimeoutError\u001b[0m: "
|
| 96 |
+
]
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"source": [
|
| 100 |
+
"from datasets import load_dataset \n",
|
| 101 |
+
"ds = load_dataset('HuggingFaceM4/VQAv2', split=\"train\", trust_remote_code=True) \n",
|
| 102 |
+
"cols_remove = [\"question_type\", \"answers\", \"answer_type\", \"image_id\", \"question_id\"] \n",
|
| 103 |
+
"ds = ds.remove_columns(cols_remove)\n",
|
| 104 |
+
"ds = ds.train_test_split(test_size=0.1)\n",
|
| 105 |
+
"train_ds = ds[\"train\"]\n",
|
| 106 |
+
"val_ds = ds[\"test\"]"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "markdown",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"source": [
|
| 113 |
+
"# Train (QLoRA 4-bit)"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "code",
|
| 118 |
+
"execution_count": null,
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"outputs": [],
|
| 121 |
+
"source": [
|
| 122 |
+
"from transformers import PaliGemmaProcessor \n",
|
| 123 |
+
"model_id = \"google/paligemma-3b-pt-224\"\n",
|
| 124 |
+
"processor = PaliGemmaProcessor.from_pretrained(model_id)"
|
| 125 |
+
]
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"cell_type": "code",
|
| 129 |
+
"execution_count": null,
|
| 130 |
+
"metadata": {},
|
| 131 |
+
"outputs": [],
|
| 132 |
+
"source": [
|
| 133 |
+
"import torch\n",
|
| 134 |
+
"device = \"cuda\"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"image_token = processor.tokenizer.convert_tokens_to_ids(\"<image>\")\n",
|
| 137 |
+
"def collate_fn(examples):\n",
|
| 138 |
+
" texts = [\"answer \" + example[\"question\"] for example in examples]\n",
|
| 139 |
+
" labels= [example['multiple_choice_answer'] for example in examples] # μ°λ¦¬λ label μ΄ νμ μμλ―?\n",
|
| 140 |
+
" images = [example[\"image\"].convert(\"RGB\") for example in examples]\n",
|
| 141 |
+
" tokens = processor(text=texts, images=images, suffix=labels,\n",
|
| 142 |
+
" return_tensors=\"pt\", padding=\"longest\")\n",
|
| 143 |
+
"\n",
|
| 144 |
+
" tokens = tokens.to(torch.bfloat16).to(device)\n",
|
| 145 |
+
" return tokens"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "code",
|
| 150 |
+
"execution_count": null,
|
| 151 |
+
"metadata": {},
|
| 152 |
+
"outputs": [],
|
| 153 |
+
"source": [
|
| 154 |
+
"from transformers import PaliGemmaForConditionalGeneration\n",
|
| 155 |
+
"model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"# Freeze the vision tower and the transformer encoder (image encoder)\n",
|
| 158 |
+
"for param in model.vision_tower.parameters():\n",
|
| 159 |
+
" param.requires_grad = False\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"# Projector is not frozen (Article μμλ freeze νλ€κ³ λμ΄μμ)\n",
|
| 162 |
+
"for param in model.multi_modal_projector.parameters():\n",
|
| 163 |
+
" param.requires_grad = True\n"
|
| 164 |
+
]
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"cell_type": "markdown",
|
| 168 |
+
"metadata": {},
|
| 169 |
+
"source": [
|
| 170 |
+
"For QLoRa in 4-bit"
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"cell_type": "code",
|
| 175 |
+
"execution_count": null,
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": [
|
| 179 |
+
"from transformers import BitsAndBytesConfig\n",
|
| 180 |
+
"from peft import get_peft_model, LoraConfig\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"bnb_config = BitsAndBytesConfig(\n",
|
| 183 |
+
" load_in_4bit=True,\n",
|
| 184 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
| 185 |
+
" bnb_4bit_compute_type=torch.bfloat16\n",
|
| 186 |
+
")\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"lora_config = LoraConfig(\n",
|
| 189 |
+
" r=8, \n",
|
| 190 |
+
" target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 191 |
+
" task_type=\"CAUSAL_LM\",\n",
|
| 192 |
+
")\n",
|
| 193 |
+
"model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
|
| 194 |
+
"model = get_peft_model(model, lora_config)\n",
|
| 195 |
+
"model.print_trainable_parameters()\n",
|
| 196 |
+
"#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344\n"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"cell_type": "code",
|
| 201 |
+
"execution_count": null,
|
| 202 |
+
"metadata": {},
|
| 203 |
+
"outputs": [],
|
| 204 |
+
"source": [
|
| 205 |
+
"from transformers import TrainingArguments\n",
|
| 206 |
+
"args=TrainingArguments(\n",
|
| 207 |
+
" num_train_epochs=2,\n",
|
| 208 |
+
" remove_unused_columns=False,\n",
|
| 209 |
+
" per_device_train_batch_size=16,\n",
|
| 210 |
+
" gradient_accumulation_steps=4,\n",
|
| 211 |
+
" warmup_steps=2,\n",
|
| 212 |
+
" learning_rate=2e-5,\n",
|
| 213 |
+
" weight_decay=1e-6,\n",
|
| 214 |
+
" adam_beta2=0.999,\n",
|
| 215 |
+
" logging_steps=100,\n",
|
| 216 |
+
" # optim=\"adamw_hf\",\n",
|
| 217 |
+
" optim=\"paged_adamw_8bit\", # for QLoRA\n",
|
| 218 |
+
" save_strategy=\"steps\",\n",
|
| 219 |
+
" save_steps=1000,\n",
|
| 220 |
+
" push_to_hub=True,\n",
|
| 221 |
+
" save_total_limit=1,\n",
|
| 222 |
+
" bf16=True,\n",
|
| 223 |
+
" report_to=[\"tensorboard\"],\n",
|
| 224 |
+
" dataloader_pin_memory=False\n",
|
| 225 |
+
" )\n"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"execution_count": null,
|
| 231 |
+
"metadata": {},
|
| 232 |
+
"outputs": [],
|
| 233 |
+
"source": [
|
| 234 |
+
"from transformers import Trainer\n",
|
| 235 |
+
"trainer = Trainer(\n",
|
| 236 |
+
" model=model,\n",
|
| 237 |
+
" train_dataset=train_ds,\n",
|
| 238 |
+
" eval_dataset=val_ds,\n",
|
| 239 |
+
" data_collator=collate_fn,\n",
|
| 240 |
+
" args=args\n",
|
| 241 |
+
" )"
|
| 242 |
+
]
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"cell_type": "code",
|
| 246 |
+
"execution_count": null,
|
| 247 |
+
"metadata": {},
|
| 248 |
+
"outputs": [],
|
| 249 |
+
"source": [
|
| 250 |
+
"trainer.train()"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "markdown",
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"source": [
|
| 257 |
+
"# Inference for test"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "code",
|
| 262 |
+
"execution_count": null,
|
| 263 |
+
"metadata": {},
|
| 264 |
+
"outputs": [],
|
| 265 |
+
"source": []
|
| 266 |
+
}
|
| 267 |
+
],
|
| 268 |
+
"metadata": {
|
| 269 |
+
"kernelspec": {
|
| 270 |
+
"display_name": "Python 3",
|
| 271 |
+
"language": "python",
|
| 272 |
+
"name": "python3"
|
| 273 |
+
},
|
| 274 |
+
"language_info": {
|
| 275 |
+
"codemirror_mode": {
|
| 276 |
+
"name": "ipython",
|
| 277 |
+
"version": 3
|
| 278 |
+
},
|
| 279 |
+
"file_extension": ".py",
|
| 280 |
+
"mimetype": "text/x-python",
|
| 281 |
+
"name": "python",
|
| 282 |
+
"nbconvert_exporter": "python",
|
| 283 |
+
"pygments_lexer": "ipython3",
|
| 284 |
+
"version": "3.12.2"
|
| 285 |
+
}
|
| 286 |
+
},
|
| 287 |
+
"nbformat": 4,
|
| 288 |
+
"nbformat_minor": 2
|
| 289 |
+
}
|