diff --git "a/train_on_parquet.ipynb" "b/train_on_parquet.ipynb"
--- "a/train_on_parquet.ipynb"
+++ "b/train_on_parquet.ipynb"
@@ -1 +1 @@
-{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/parquet_explorer.ipynb","timestamp":1754497857381},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754475181338},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754312448728},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754310418707},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1754223895158},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1747490904984},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1740037333374},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1736477078136},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1725365086834}]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["Download a parquet file to your Google drive and load it from there into this notebook.\n","\n","Parquet files: https://huggingface.co/datasets/codeShare/chroma_prompts/tree/main\n","\n","E621 JSON files: https://huggingface.co/datasets/lodestones/e621-captions/tree/main"],"metadata":{"id":"LeCfcqgiQvCP"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"HFy5aDxM3G7O"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KYv7Y2gNPW_i"},"outputs":[],"source":["#@markdown Investigate a json file\n","\n","import json\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Initialize lists to store data\n","data = []\n","\n","# Read the .jsonl file line by line\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," # Parse each line as a JSON object\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert the list of JSON objects to a Pandas DataFrame for easier exploration\n","df = pd.DataFrame(data)\n","\n","# Display basic information about the DataFrame\n","print(\"=== File Overview ===\")\n","print(f\"Number of records: {len(df)}\")\n","print(\"\\nColumn names:\")\n","print(df.columns.tolist())\n","print(\"\\nData types:\")\n","print(df.dtypes)\n","\n","# Display the first few rows\n","print(\"\\n=== First 5 Rows ===\")\n","print(df.head())\n","\n","# Display basic statistics\n","print(\"\\n=== Basic Statistics ===\")\n","print(df.describe(include='all'))\n","\n","# Check for missing values\n","print(\"\\n=== Missing Values ===\")\n","print(df.isnull().sum())\n","\n","# Optional: Display unique values in each column\n","print(\"\\n=== Unique Values per Column ===\")\n","for col in df.columns:\n"," print(f\"{col}: {df[col].nunique()} unique values\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dnIWOPPqSTnw"},"outputs":[],"source":["#@markdown Investigate a json file pt 2\n","\n","import json\n","import 
pandas as pd\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from collections import Counter\n","import numpy as np\n","\n","# Set up plotting style\n","plt.style.use('default')\n","%matplotlib inline\n","\n","# Path to the uploaded .jsonl file\n","#file_path = ''\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# 1. Rating Distribution\n","print(\"=== Rating Distribution ===\")\n","rating_counts = df['rating'].value_counts()\n","plt.figure(figsize=(8, 5))\n","sns.barplot(x=rating_counts.index, y=rating_counts.values)\n","plt.title('Distribution of Image Ratings')\n","plt.xlabel('Rating')\n","plt.ylabel('Count')\n","plt.show()\n","print(rating_counts)\n","\n","# 2. Tag Analysis\n","print(\"\\n=== Top 10 Most Common Tags ===\")\n","# Combine all tags into a single list\n","all_tags = []\n","for tags in df['tag_string'].dropna():\n"," all_tags.extend(tags.split())\n","tag_counts = Counter(all_tags)\n","top_tags = pd.DataFrame(tag_counts.most_common(10), columns=['Tag', 'Count'])\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x='Count', y='Tag', data=top_tags)\n","plt.title('Top 10 Most Common Tags')\n","plt.show()\n","print(top_tags)\n","\n","# 3. Image Dimensions Analysis\n","print(\"\\n=== Image Dimensions Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'], df['image_height'], alpha=0.5, s=50)\n","plt.title('Image Width vs. Height')\n","plt.xlabel('Width (pixels)')\n","plt.ylabel('Height (pixels)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Median Width: {df['image_width'].median()}\")\n","print(f\"Median Height: {df['image_height'].median()}\")\n","print(f\"Aspect Ratio (Width/Height) Stats:\\n{df['image_width'].div(df['image_height']).describe()}\")\n","\n","# 4. Score and Voting Analysis\n","print(\"\\n=== Score and Voting Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['score'], bins=30, kde=True)\n","plt.title('Distribution of Image Scores')\n","plt.xlabel('Score')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Score Stats:\\n{df['score'].describe()}\")\n","print(f\"\\nCorrelation between Up Score and Down Score: {df['up_score'].corr(df['down_score'])}\")\n","\n","# 5. Summary Length Analysis\n","print(\"\\n=== Summary Length Analysis ===\")\n","df['summary_length'] = df['regular_summary'].dropna().apply(lambda x: len(str(x).split()))\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['summary_length'], bins=30, kde=True)\n","plt.title('Distribution of Regular Summary Word Counts')\n","plt.xlabel('Word Count')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Summary Length Stats:\\n{df['summary_length'].describe()}\")\n","\n","# 6. Missing Data Heatmap\n","print(\"\\n=== Missing Data Heatmap ===\")\n","plt.figure(figsize=(12, 8))\n","sns.heatmap(df.isnull(), cbar=False, cmap='viridis')\n","plt.title('Missing Data Heatmap')\n","plt.show()\n","\n","# 7. 
Source Platform Analysis\n","print(\"\\n=== Source Platform Analysis ===\")\n","# Extract domain from source URLs\n","df['source_domain'] = df['source'].dropna().str.extract(r'(https?://[^/]+)')\n","source_counts = df['source_domain'].value_counts().head(10)\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x=source_counts.values, y=source_counts.index)\n","plt.title('Top 10 Source Domains')\n","plt.xlabel('Count')\n","plt.ylabel('Domain')\n","plt.show()\n","print(source_counts)\n","\n","# 8. File Size vs. Image Dimensions\n","print(\"\\n=== File Size vs. Image Dimensions ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'] * df['image_height'], df['file_size'], alpha=0.5)\n","plt.title('File Size vs. Image Area')\n","plt.xlabel('Image Area (Width * Height)')\n","plt.ylabel('File Size (bytes)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Correlation between Image Area and File Size: {df['file_size'].corr(df['image_width'] * df['image_height'])}\")"]},{"cell_type":"code","source":["#@markdown convert E621 JSON to parquet file\n","\n","import json,os\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# Define columns that likely contain prompts/image descriptions\n","description_columns = [\n"," 'regular_summary',\n"," 'individual_parts',\n"," 'midjourney_style_summary',\n"," 'deviantart_commission_request',\n"," 'brief_summary'\n","]\n","\n","# Initialize a list to store all descriptions\n","all_descriptions = []\n","\n","# Iterate through each row and collect non-empty descriptions\n","for index, row in df.iterrows():\n"," record_descriptions = []\n"," for col in description_columns:\n"," if pd.notnull(row[col]) and row[col]: # Check for non-null and non-empty values\n"," record_descriptions.append(f\"{col}: {row[col]}\")\n"," if record_descriptions:\n"," all_descriptions.append({\n"," 'id': row['id'],\n"," 'descriptions': '; '.join(record_descriptions) # Join descriptions with semicolon\n"," })\n","\n","# Convert to DataFrame for Parquet\n","output_df = pd.DataFrame(all_descriptions)\n","\n","# Save to Parquet file\n","output_path = '' #@param {type:'string'}\n","output_df.to_parquet(output_path, index=False)\n","os.remove(f'{file_path}')\n","print(f\"\\nDescriptions have been saved to '{output_path}'.\")"],"metadata":{"id":"-NXBRSv4jsUS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Step 1: Mount Google Drive\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","#@markdown paste .parquet file stored on your Google Drive folder to see its characteristics\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Basic exploration of the Parquet file\n","print(\"First 5 rows of the dataset:\")\n","print(df.head())\n","\n","print(\"\\nDataset Info:\")\n","print(df.info())\n","\n","print(\"\\nBasic Statistics:\")\n","print(df.describe())\n","\n","print(\"\\nColumn 
Names:\")\n","print(df.columns.tolist())\n","\n","print(\"\\nMissing Values:\")\n","print(df.isnull().sum())\n","\n","# Optional: Display number of rows and columns\n","print(f\"\\nShape of the dataset: {df.shape}\")"],"metadata":{"id":"So-PKtbo5AVA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Read contents of a .parquet file\n","\n","# Import pandas\n","import pandas as pd\n","\n","# Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","parquet_column = 'descriptions' #@param {type:'string'}\n","# Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Set pandas display options to show full text without truncation\n","pd.set_option('display.max_colwidth', None) # Show full content of columns\n","pd.set_option('display.width', None) # Use full display width\n","\n","# Create sliders for selecting the range of captions\n","#@markdown Caption Range { run: \"auto\", display_mode: \"form\" }\n","start_at = 16814 #@param {type:\"slider\", min:0, max:33147, step:1}\n","range = 247 #@param {type:'slider',min:1,max:1000,step:1}\n","start_index = start_at\n","end_index = start_at + range\n","###@param {type:\"slider\", min:1, max:33148, step:1}\n","\n","include_either_words = '' #@param {type:'string', placeholder:'item1,item2...'}\n","#display_only = True #@param {type:'boolean'}\n","\n","_include_either_words = ''\n","for include_word in include_either_words.split(','):\n"," if include_word.strip()=='':continue\n"," _include_either_words= include_either_words + include_word.lower()+','+include_word.title() +','\n","#-----#\n","_include_either_words = _include_either_words[:len(_include_either_words)-1]\n","\n","\n","# Ensure end_index is greater than start_index and within bounds\n","if end_index <= start_index:\n"," print(\"Error: End index must be greater than start index.\")\n","elif end_index > len(df):\n"," print(f\"Error: End index cannot exceed {len(df)}. Setting to maximum value.\")\n"," end_index = len(df)\n","elif start_index < 0:\n"," print(\"Error: Start index cannot be negative. 
Setting to 0.\")\n"," start_index = 0\n","\n","# Display the selected range of captions\n","tmp =''\n","\n","categories= ['regular_summary:',';midjourney_style_summary:', 'individual_parts:']\n","\n","print(f\"\\nDisplaying captions from index {start_index} to {end_index-1}:\")\n","for index, caption in df[f'{parquet_column}'][start_index:end_index].items():\n"," for include_word in _include_either_words.split(','):\n"," found = True\n"," if (include_word.strip() in caption) or include_word.strip()=='':\n"," #----#\n"," if not found: continue\n"," tmp= caption + '\\n\\n'\n"," for category in categories:\n"," tmp = tmp.replace(f'{category}',f'\\n\\n{category}\\n')\n"," #----#\n"," print(f'Index {index}: {tmp}')\n"],"metadata":{"id":"wDhyb8M_7pkD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 200 #@param {type:'slider',max:1000}\n","\n","\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly select 300 rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have 200 valid images\n"," break\n"," url = row['url']\n"," caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly 200 if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", 
type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","# Optional: Save the dataset to disk (if needed)\n","#dataset.save_to_disk('/content/drive/MyDrive/Chroma prompts/custom_dataset')"],"metadata":{"id":"XZvpJ5zw0fzR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset_name=''#@param {type:'string'}\n","\n","if dataset_name.strip()=='':\n"," dataset_name='my_dataset'\n","\n","\n","dataset.save_to_disk(f'/content/drive/MyDrive/{dataset_name}')\n","\n","\n"],"metadata":{"id":"iTyxazlM1OAn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"sQmoYDLHUXxF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"jFnWBQHa142R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown load two .parquet datasets for merging\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset1_path = '' #@param {type: 'string'}\n","\n","dataset2_path = '' #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset1)\n","print(dataset2)\n","\n","# Step 7: Example of accessing an image and text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", redcaps_dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"LoCcBJqs4pzL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"AmLgPcrdRqCJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"X5HLZqjTRt7L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Merge the two .parquet files into one\n","\n","# Step 1: Import required libraries\n","from datasets import load_from_disk, concatenate_datasets\n","from google.colab import drive\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","drive.mount('/content/drive')\n","\n","# Step 3: Define paths for the datasets\n","dataset1_path = '' #@param {type:'string'}\n","dataset2_path = '' #@param {type:'string'}\n","merged_dataset_path = '' #@param {type:'string'}\n","\n","# Step 4: Load the datasets\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Datasets loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading datasets: {e}\")\n"," raise\n","\n","# Step 5: Verify the datasets\n","print(\"Dataset 1:\", 
dataset1)\n","print(\"Dataset 2:\", dataset2)\n","\n","# Step 6: Merge the datasets\n","try:\n"," merged_dataset = concatenate_datasets([dataset1, dataset2])\n"," print(\"Datasets merged successfully!\")\n","except Exception as e:\n"," print(f\"Error merging datasets: {e}\")\n"," raise\n","\n","# Step 7: Verify the merged dataset\n","print(\"Merged Dataset:\", merged_dataset)\n","\n","# Step 8: Save the merged dataset to Google Drive\n","try:\n"," merged_dataset.save_to_disk(merged_dataset_path)\n"," print(f\"Merged dataset saved successfully to {merged_dataset_path}\")\n","except Exception as e:\n"," print(f\"Error saving merged dataset: {e}\")\n"," raise\n","\n","# Step 9: Optional - Verify the saved dataset by loading it back\n","try:\n"," loaded_merged_dataset = load_from_disk(merged_dataset_path)\n"," print(\"Saved merged dataset loaded successfully for verification:\")\n"," print(loaded_merged_dataset)\n","except Exception as e:\n"," print(f\"Error loading saved merged dataset: {e}\")\n"," raise"],"metadata":{"id":"HF_cmJu1EMJV"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["🔄 Change to T4 Runtime : Past this point you can train a LoRa on the Dataset , but you need to change the runtime to T4 for that first\n","\n","See original file at:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb"],"metadata":{"id":"0Kmf1OP6Se4Q"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"ESLqweKz4xM_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Test the merged dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = ''#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"xUA37h2APkWc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"4hCnrtv6R9B1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"MSetS3MCR2qJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9CBpiISFa6C"},"source":["To format the dataset, all vision fine-tuning tasks should follow this format:\n","\n","```python\n","[\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": 
instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n","]\n","```"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oPXzJZzHEgXe"},"outputs":[],"source":["#@markdown Convert the merged dataset to the 'correct' format for training the Gemma LoRa model\n","\n","instruction = \"Describe this image.\" # <- Select the prompt for your use case here\n","\n","def convert_to_conversation(sample):\n"," conversation = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": sample[\"text\"]}]},\n"," ]\n"," return {\"messages\": conversation}\n","pass"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gFW2qXIr7Ezy"},"outputs":[],"source":["converted_dataset = [convert_to_conversation(sample) for sample in dataset]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gGFzmplrEy9I"},"outputs":[],"source":["converted_dataset[0]"]},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9vJOucOw1qc6"},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n"," !pip install unsloth\n","else:\n"," # Do this only in Colab notebooks! Otherwise use pip install unsloth\n"," !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n"," !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n"," !pip install --no-deps unsloth"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QmUBVEnvCDJv"},"outputs":[],"source":["from unsloth import FastVisionModel # FastLanguageModel for LLMs\n","import torch\n","\n","# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n","fourbit_models = [\n"," \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\", # Llama 3.2 vision support\n"," \"unsloth/Llama-3.2-11B-Vision-bnb-4bit\",\n"," \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\", # Can fit in a 80GB card!\n"," \"unsloth/Llama-3.2-90B-Vision-bnb-4bit\",\n","\n"," \"unsloth/Pixtral-12B-2409-bnb-4bit\", # Pixtral fits in 16GB!\n"," \"unsloth/Pixtral-12B-Base-2409-bnb-4bit\", # Pixtral base model\n","\n"," \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\", # Qwen2 VL support\n"," \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n"," \"unsloth/Qwen2-VL-72B-Instruct-bnb-4bit\",\n","\n"," \"unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit\", # Any Llava variant works!\n"," \"unsloth/llava-1.5-7b-hf-bnb-4bit\",\n","] # More models at https://huggingface.co/unsloth\n","\n","model, processor = FastVisionModel.from_pretrained(\n"," \"unsloth/gemma-3-4b-pt\",\n"," load_in_4bit = True, # Use 4bit to reduce memory use. 
False for 16bit LoRA.\n"," use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for long context\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bEzvL7Sm1CrS"},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n"," processor,\n"," \"gemma-3\"\n",")"]},{"cell_type":"markdown","metadata":{"id":"SXd9bTZd1aaL"},"source":["We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.\n","\n","**[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bZsfBuZDeCL"},"outputs":[],"source":["model = FastVisionModel.get_peft_model(\n"," model,\n"," finetune_vision_layers = True, # False if not finetuning vision layers\n"," finetune_language_layers = True, # False if not finetuning language layers\n"," finetune_attention_modules = True, # False if not finetuning attention layers\n"," finetune_mlp_modules = True, # False if not finetuning MLP layers\n","\n"," r = 16, # The larger, the higher the accuracy, but might overfit\n"," lora_alpha = 16, # Recommended alpha == r at least\n"," lora_dropout = 0,\n"," bias = \"none\",\n"," random_state = 3408,\n"," use_rslora = False, # We support rank stabilized LoRA\n"," loftq_config = None, # And LoftQ\n"," target_modules = \"all-linear\", # Optional now! Can specify a list if needed\n"," modules_to_save=[\n"," \"lm_head\",\n"," \"embed_tokens\",\n"," ],\n",")"]},{"cell_type":"markdown","metadata":{"id":"FecKS-dA82f5"},"source":["Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vcat4UxA81vr"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[2][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"idAEIeSQ3xdS"},"source":["\n","### Train the model\n","Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. 
We also support TRL's `DPOTrainer`!\n","\n","We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"95_Nn-89DhsL"},"outputs":[],"source":["from unsloth.trainer import UnslothVisionDataCollator\n","from trl import SFTTrainer, SFTConfig\n","\n","FastVisionModel.for_training(model) # Enable for training!\n","\n","trainer = SFTTrainer(\n"," model=model,\n"," train_dataset=converted_dataset,\n"," processing_class=processor.tokenizer,\n"," data_collator=UnslothVisionDataCollator(model, processor),\n"," args = SFTConfig(\n"," per_device_train_batch_size = 1,\n"," gradient_accumulation_steps = 4,\n"," gradient_checkpointing = True,\n","\n"," # use reentrant checkpointing\n"," gradient_checkpointing_kwargs = {\"use_reentrant\": False},\n"," max_grad_norm = 0.3, # max gradient norm based on QLoRA paper\n"," warmup_ratio = 0.03,\n"," #max_steps = 30,\n"," num_train_epochs = 5, # Set this instead of max_steps for full training runs\n"," learning_rate = 2e-4,\n"," logging_steps = 1,\n"," save_strategy=\"steps\",\n"," optim = \"adamw_torch_fused\",\n"," weight_decay = 0.01,\n"," lr_scheduler_type = \"cosine\",\n"," seed = 3407,\n"," output_dir = \"outputs\",\n"," report_to = \"none\", # For Weights and Biases\n","\n"," # You MUST put the below items for vision finetuning:\n"," remove_unused_columns = False,\n"," dataset_text_field = \"\",\n"," dataset_kwargs = {\"skip_prepare_dataset\": True},\n"," max_length = 2048,\n"," )\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"2ejIt2xSNKKp"},"outputs":[],"source":["# @title Show current memory stats\n","gpu_stats = torch.cuda.get_device_properties(0)\n","start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n","print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n","print(f\"{start_gpu_memory} GB of memory reserved.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yqxqAZ7KJ4oL"},"outputs":[],"source":["trainer_stats = trainer.train()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"pCqnaKmlO1U9"},"outputs":[],"source":["# @title Show final memory and time stats\n","used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n","used_percentage = round(used_memory / max_memory * 100, 3)\n","lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n","print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n","print(\n"," f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",")\n","print(f\"Peak reserved memory = {used_memory} GB.\")\n","print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n","print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n","print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"]},{"cell_type":"markdown","metadata":{"id":"ekOmTR1hSNcr"},"source":["\n","### Inference\n","Let's run the model! 
You can modify the instruction and input—just leave the output blank.\n","\n","We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kR3gIAX-SM2q"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[10][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"code","source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files # For file upload in Colab\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload() # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n"," raise ValueError(\"No file uploaded. Please upload an image.\")\n","\n","# Get the first uploaded file\n","file_name = list(uploaded.keys())[0]\n","try:\n"," image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n","except Exception as e:\n"," raise ValueError(f\"Error loading image: {e}\")\n","\n","# Step 5: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 6: Prepare messages for the model\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","# Step 7: Apply chat template and prepare inputs\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","# Step 8: Generate output with text streaming\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(\n"," **inputs,\n"," streamer=text_streamer,\n"," max_new_tokens=512,\n"," use_cache=True,\n"," temperature=1.0,\n"," top_p=0.95,\n"," top_k=64\n",")"],"metadata":{"id":"oOyy5FUh8fBi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uMuVrWbjAzhc"},"source":["\n","### Saving, loading finetuned models\n","To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.\n","\n","**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. 
To save to 16bit or GGUF, scroll down!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"upcOlWe7A1vc"},"outputs":[],"source":["model.save_pretrained(\"lora_model\") # Local saving\n","processor.save_pretrained(\"lora_model\")\n","# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n","# processor.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"]},{"cell_type":"markdown","metadata":{"id":"AEEcJ4qfC7Lp"},"source":["Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MKX_XKs_BNZR"},"outputs":[],"source":["if False:\n"," from unsloth import FastVisionModel\n","\n"," model, processor = FastVisionModel.from_pretrained(\n"," model_name=\"lora_model\", # YOUR MODEL YOU USED FOR TRAINING\n"," load_in_4bit=True, # Set to False for 16bit LoRA\n"," )\n"," FastVisionModel.for_inference(model) # Enable for inference!\n","\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","sample = dataset[1]\n","image = sample[\"image\"].convert(\"RGB\")\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\n"," \"type\": \"text\",\n"," \"text\": sample[\"text\"],\n"," },\n"," {\n"," \"type\": \"image\",\n"," },\n"," ],\n"," },\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)\n","_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"f422JgM9sdVT"},"source":["### Saving to float16 for VLLM\n","\n","We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."]}]}
\ No newline at end of file
+{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754519491020},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/parquet_explorer.ipynb","timestamp":1754497857381},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754475181338},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754312448728},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754310418707},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1754223895158},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1747490904984},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1740037333374},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1736477078136},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1725365086834}]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"2ee3770bdd084ea5a006437ef64c278a":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_3286274920f142309915f5b42052011b","IPY_MODEL_79c5530314664494994a62df7eb1ee92","IPY_MODEL_7fb478f298b14a6a9314708d3750aa4e"],"layout":"IPY_MODEL_3a0e3893879c4cdcb7074a16ec4a05d0"}},"3286274920f142309915f5b42052011b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_1cad313748734432a879e4f61bd04a9d","placeholder":"​","style":"IPY_MODEL_92a9c8b06f4b49a4b9a5d9977987d487","value":"Saving the dataset (1/1 shards): 
100%"}},"79c5530314664494994a62df7eb1ee92":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_91d9082686944b6ea53c57c3d271c112","max":200,"min":0,"orientation":"horizontal","style":"IPY_MODEL_e7e636621f2349cbb77286d40e9281dd","value":200}},"7fb478f298b14a6a9314708d3750aa4e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_a3268d9c69344d7297d6efb49e024a3e","placeholder":"​","style":"IPY_MODEL_d5658c563f83405e97d6611a115ab22b","value":" 200/200 [00:00<00:00, 270.83 examples/s]"}},"3a0e3893879c4cdcb7074a16ec4a05d0":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1cad313748734432a879e4f61bd04a9d":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"92a9c8b06f4b49a4b9a5d9977987d487":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0",
"_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"91d9082686944b6ea53c57c3d271c112":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"e7e636621f2349cbb77286d40e9281dd":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"a3268d9c69344d7297d6efb49e024a3e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d5658c563f83405e97d6611a115ab22b":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["Download a parquet file to your Google drive and load it from there into this notebook.\n","\n","Parquet files: https://huggingface.co/datasets/codeShare/chroma_prompts/tree/main\n","\n","E621 JSON files: https://huggingface.co/datasets/lodestones/e621-captions/tree/main"],"metadata":{"id":"LeCfcqgiQvCP"}},{"cell_type":"code","source":["from google.colab import 
drive\n","drive.mount('/content/drive')"],"metadata":{"id":"HFy5aDxM3G7O","executionInfo":{"status":"ok","timestamp":1754519627735,"user_tz":-120,"elapsed":103567,"user":{"displayName":"No Name","userId":"10578412414437288386"}},"outputId":"261502b8-4439-45db-be97-a0ffa47b28ab","colab":{"base_uri":"https://localhost:8080/"}},"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"code","source":["#@markdown Build a dataset for training using a .jsonl file\n","\n","num_dataset_items = 800 #@param {type:'slider', max:10000}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import json\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import math,random\n","\n","# Step 3: Define the path to the JSONL file\n","file_path = '/content/drive/MyDrive/Saved from Chrome/2022-08_grouped.jsonl' #@param {type:'string'}\n","\n","# Step 4: Read the JSONL file\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert to DataFrame\n","df = pd.DataFrame(data)\n","\n","# Step 5: Randomly select rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items * 1.1), random_state=math.floor(random.random()*10000)).reset_index(drop=True)\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," #num=num+1\n"," #print(f\"{num}\")\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","num=1\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have enough valid images\n"," break\n"," url = row['url']\n"," # Combine description and tag_string for caption, ensuring no missing values\n"," description = row['description'] if pd.notnull(row['description']) else ''\n"," tag_string = row['tag_string'] if pd.notnull(row['tag_string']) else ''\n"," caption = f\"{description}, {tag_string}\".strip(', ')\n","\n"," num=num+1\n"," print(f'{num}')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly num_dataset_items if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': 
texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","output_name='dataset3'#@param {type:'string'}\n","# Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'/content/{output_name}')"],"metadata":{"id":"jtPz3voOhnBj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset=''"],"metadata":{"id":"lOkICBHuuGAQ","executionInfo":{"status":"ok","timestamp":1754523104327,"user_tz":-120,"elapsed":4,"user":{"displayName":"No Name","userId":"10578412414437288386"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["#@markdown Merge the two datasets into one\n","\n","# Step 1: Import required libraries\n","from datasets import load_from_disk, concatenate_datasets\n","from google.colab import drive\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","drive.mount('/content/drive')\n","\n","# Step 3: Define paths for the datasets\n","dataset1_path = '/content/dataset12' #@param {type:'string'}\n","dataset2_path = '/content/dataset3' #@param {type:'string'}\n","merged_dataset_path = '/content/drive/MyDrive/dataset_e621' #@param {type:'string'}\n","\n","# Step 4: Load the datasets\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Datasets loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading datasets: {e}\")\n"," raise\n","\n","# Step 5: Verify the datasets\n","print(\"Dataset 1:\", dataset1)\n","print(\"Dataset 2:\", dataset2)\n","\n","# Step 6: Merge the datasets\n","try:\n"," merged_dataset = concatenate_datasets([dataset1, dataset2])\n"," print(\"Datasets merged successfully!\")\n","except Exception as e:\n"," print(f\"Error merging datasets: {e}\")\n"," raise\n","\n","# Step 7: Verify the merged dataset\n","print(\"Merged Dataset:\", merged_dataset)\n","\n","# Step 8: Save the merged dataset to Google Drive\n","try:\n"," merged_dataset.save_to_disk(merged_dataset_path)\n"," print(f\"Merged dataset saved successfully to {merged_dataset_path}\")\n","except Exception as e:\n"," print(f\"Error saving merged dataset: {e}\")\n"," raise\n","\n","# Step 9: Optional - Verify the saved dataset by loading it back\n","try:\n"," loaded_merged_dataset = load_from_disk(merged_dataset_path)\n"," print(\"Saved merged dataset loaded successfully for verification:\")\n"," print(loaded_merged_dataset)\n","except Exception as e:\n"," print(f\"Error loading saved merged dataset: {e}\")\n"," raise"],"metadata":{"id":"HF_cmJu1EMJV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown load two datasets for merging\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset1_path = '' #@param {type: 'string'}\n","\n","dataset2_path = '' #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Dataset loaded 
successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset1)\n","print(dataset2)\n","\n","# Step 7: Example of accessing an image and text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", redcaps_dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"LoCcBJqs4pzL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset.save_to_disk(f'/content/drive/MyDrive/{output_name}')\n","\n","\n","\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":77,"referenced_widgets":["2ee3770bdd084ea5a006437ef64c278a","3286274920f142309915f5b42052011b","79c5530314664494994a62df7eb1ee92","7fb478f298b14a6a9314708d3750aa4e","3a0e3893879c4cdcb7074a16ec4a05d0","1cad313748734432a879e4f61bd04a9d","92a9c8b06f4b49a4b9a5d9977987d487","91d9082686944b6ea53c57c3d271c112","e7e636621f2349cbb77286d40e9281dd","a3268d9c69344d7297d6efb49e024a3e","d5658c563f83405e97d6611a115ab22b"]},"id":"V2o9DjTNjIzr","executionInfo":{"status":"ok","timestamp":1754520201990,"user_tz":-120,"elapsed":823,"user":{"displayName":"No Name","userId":"10578412414437288386"}},"outputId":"10645556-2593-46f0-e884-14f8f48e8c75"},"execution_count":4,"outputs":[{"output_type":"display_data","data":{"text/plain":["Saving the dataset (0/1 shards): 0%| | 0/200 [00:00, ? examples/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"2ee3770bdd084ea5a006437ef64c278a"}},"metadata":{}}]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 200 #@param {type:'slider',max:1000}\n","\n","outout_name='dataset1.parquet'#@param {type:'string'}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '/content/drive/MyDrive/Saved from Chrome/2022-08_grouped.jsonl' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly select 300 rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have 200 valid images\n"," break\n"," url = row['url']\n"," caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image 
displays','')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly 200 if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","#Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'/content/drive/MyDrive/{output_name}')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":711},"executionInfo":{"status":"error","timestamp":1754519693014,"user_tz":-120,"elapsed":3171,"user":{"displayName":"No Name","userId":"10578412414437288386"}},"outputId":"c151467d-b57d-4236-f6da-1cb04c4def0a","id":"ENA-zhQHhXcV"},"execution_count":2,"outputs":[{"output_type":"error","ename":"ArrowInvalid","evalue":"Could not open Parquet input source '': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mArrowInvalid\u001b[0m Traceback (most recent call last)","\u001b[0;32m/tmp/ipython-input-2772223238.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;31m# Step 4: Read the Parquet file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_parquet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# Step 5: Randomly select 300 rows to account for potential image loading failures\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/pandas/io/parquet.py\u001b[0m in \u001b[0;36mread_parquet\u001b[0;34m(path, engine, columns, storage_options, use_nullable_dtypes, dtype_backend, filesystem, filters, **kwargs)\u001b[0m\n\u001b[1;32m 665\u001b[0m \u001b[0mcheck_dtype_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype_backend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 667\u001b[0;31m return impl.read(\n\u001b[0m\u001b[1;32m 668\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 669\u001b[0m 
\u001b[0mcolumns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/pandas/io/parquet.py\u001b[0m in \u001b[0;36mread\u001b[0;34m(self, path, columns, filters, use_nullable_dtypes, dtype_backend, storage_options, filesystem, **kwargs)\u001b[0m\n\u001b[1;32m 272\u001b[0m )\n\u001b[1;32m 273\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m pa_table = self.api.parquet.read_table(\n\u001b[0m\u001b[1;32m 275\u001b[0m \u001b[0mpath_or_handle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/pyarrow/parquet/core.py\u001b[0m in \u001b[0;36mread_table\u001b[0;34m(source, columns, use_threads, schema, use_pandas_metadata, read_dictionary, memory_map, buffer_size, partitioning, filesystem, filters, use_legacy_dataset, ignore_prefixes, pre_buffer, coerce_int96_timestamp_unit, decryption_properties, thrift_string_size_limit, thrift_container_size_limit, page_checksum_verification)\u001b[0m\n\u001b[1;32m 1791\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1792\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1793\u001b[0;31m dataset = ParquetDataset(\n\u001b[0m\u001b[1;32m 1794\u001b[0m \u001b[0msource\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1795\u001b[0m \u001b[0mschema\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mschema\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/pyarrow/parquet/core.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, path_or_paths, filesystem, schema, filters, read_dictionary, memory_map, buffer_size, partitioning, ignore_prefixes, pre_buffer, coerce_int96_timestamp_unit, decryption_properties, thrift_string_size_limit, thrift_container_size_limit, page_checksum_verification, use_legacy_dataset)\u001b[0m\n\u001b[1;32m 1358\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1359\u001b[0m self._dataset = ds.FileSystemDataset(\n\u001b[0;32m-> 1360\u001b[0;31m \u001b[0;34m[\u001b[0m\u001b[0mfragment\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mschema\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mschema\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfragment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mphysical_schema\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1361\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparquet_format\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1362\u001b[0m \u001b[0mfilesystem\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfragment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilesystem\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/pyarrow/_dataset.pyx\u001b[0m in \u001b[0;36mpyarrow._dataset.Fragment.physical_schema.__get__\u001b[0;34m()\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/pyarrow/error.pxi\u001b[0m in 
\u001b[0;36mpyarrow.lib.pyarrow_internal_check_status\u001b[0;34m()\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/pyarrow/error.pxi\u001b[0m in \u001b[0;36mpyarrow.lib.check_status\u001b[0;34m()\u001b[0m\n","\u001b[0;31mArrowInvalid\u001b[0m: Could not open Parquet input source '': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file."]}]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KYv7Y2gNPW_i"},"outputs":[],"source":["#@markdown Investigate a json file\n","\n","import json\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Initialize lists to store data\n","data = []\n","\n","# Read the .jsonl file line by line\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," # Parse each line as a JSON object\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert the list of JSON objects to a Pandas DataFrame for easier exploration\n","df = pd.DataFrame(data)\n","\n","# Display basic information about the DataFrame\n","print(\"=== File Overview ===\")\n","print(f\"Number of records: {len(df)}\")\n","print(\"\\nColumn names:\")\n","print(df.columns.tolist())\n","print(\"\\nData types:\")\n","print(df.dtypes)\n","\n","# Display the first few rows\n","print(\"\\n=== First 5 Rows ===\")\n","print(df.head())\n","\n","# Display basic statistics\n","print(\"\\n=== Basic Statistics ===\")\n","print(df.describe(include='all'))\n","\n","# Check for missing values\n","print(\"\\n=== Missing Values ===\")\n","print(df.isnull().sum())\n","\n","# Optional: Display unique values in each column\n","print(\"\\n=== Unique Values per Column ===\")\n","for col in df.columns:\n"," print(f\"{col}: {df[col].nunique()} unique values\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dnIWOPPqSTnw"},"outputs":[],"source":["#@markdown Investigate a json file pt 2\n","\n","import json\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from collections import Counter\n","import numpy as np\n","\n","# Set up plotting style\n","plt.style.use('default')\n","%matplotlib inline\n","\n","# Path to the uploaded .jsonl file\n","#file_path = ''\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# 1. Rating Distribution\n","print(\"=== Rating Distribution ===\")\n","rating_counts = df['rating'].value_counts()\n","plt.figure(figsize=(8, 5))\n","sns.barplot(x=rating_counts.index, y=rating_counts.values)\n","plt.title('Distribution of Image Ratings')\n","plt.xlabel('Rating')\n","plt.ylabel('Count')\n","plt.show()\n","print(rating_counts)\n","\n","# 2. 
Tag Analysis\n","print(\"\\n=== Top 10 Most Common Tags ===\")\n","# Combine all tags into a single list\n","all_tags = []\n","for tags in df['tag_string'].dropna():\n"," all_tags.extend(tags.split())\n","tag_counts = Counter(all_tags)\n","top_tags = pd.DataFrame(tag_counts.most_common(10), columns=['Tag', 'Count'])\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x='Count', y='Tag', data=top_tags)\n","plt.title('Top 10 Most Common Tags')\n","plt.show()\n","print(top_tags)\n","\n","# 3. Image Dimensions Analysis\n","print(\"\\n=== Image Dimensions Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'], df['image_height'], alpha=0.5, s=50)\n","plt.title('Image Width vs. Height')\n","plt.xlabel('Width (pixels)')\n","plt.ylabel('Height (pixels)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Median Width: {df['image_width'].median()}\")\n","print(f\"Median Height: {df['image_height'].median()}\")\n","print(f\"Aspect Ratio (Width/Height) Stats:\\n{df['image_width'].div(df['image_height']).describe()}\")\n","\n","# 4. Score and Voting Analysis\n","print(\"\\n=== Score and Voting Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['score'], bins=30, kde=True)\n","plt.title('Distribution of Image Scores')\n","plt.xlabel('Score')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Score Stats:\\n{df['score'].describe()}\")\n","print(f\"\\nCorrelation between Up Score and Down Score: {df['up_score'].corr(df['down_score'])}\")\n","\n","# 5. Summary Length Analysis\n","print(\"\\n=== Summary Length Analysis ===\")\n","df['summary_length'] = df['regular_summary'].dropna().apply(lambda x: len(str(x).split()))\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['summary_length'], bins=30, kde=True)\n","plt.title('Distribution of Regular Summary Word Counts')\n","plt.xlabel('Word Count')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Summary Length Stats:\\n{df['summary_length'].describe()}\")\n","\n","# 6. Missing Data Heatmap\n","print(\"\\n=== Missing Data Heatmap ===\")\n","plt.figure(figsize=(12, 8))\n","sns.heatmap(df.isnull(), cbar=False, cmap='viridis')\n","plt.title('Missing Data Heatmap')\n","plt.show()\n","\n","# 7. Source Platform Analysis\n","print(\"\\n=== Source Platform Analysis ===\")\n","# Extract domain from source URLs\n","df['source_domain'] = df['source'].dropna().str.extract(r'(https?://[^/]+)')\n","source_counts = df['source_domain'].value_counts().head(10)\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x=source_counts.values, y=source_counts.index)\n","plt.title('Top 10 Source Domains')\n","plt.xlabel('Count')\n","plt.ylabel('Domain')\n","plt.show()\n","print(source_counts)\n","\n","# 8. File Size vs. Image Dimensions\n","print(\"\\n=== File Size vs. Image Dimensions ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'] * df['image_height'], df['file_size'], alpha=0.5)\n","plt.title('File Size vs. 
Image Area')\n","plt.xlabel('Image Area (Width * Height)')\n","plt.ylabel('File Size (bytes)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Correlation between Image Area and File Size: {df['file_size'].corr(df['image_width'] * df['image_height'])}\")"]},{"cell_type":"code","source":["#@markdown convert E621 JSON to parquet file\n","\n","import json,os\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# Define columns that likely contain prompts/image descriptions\n","description_columns = [\n"," 'regular_summary',\n"," 'individual_parts',\n"," 'midjourney_style_summary',\n"," 'deviantart_commission_request',\n"," 'brief_summary'\n","]\n","\n","# Initialize a list to store all descriptions\n","all_descriptions = []\n","\n","# Iterate through each row and collect non-empty descriptions\n","for index, row in df.iterrows():\n"," record_descriptions = []\n"," for col in description_columns:\n"," if pd.notnull(row[col]) and row[col]: # Check for non-null and non-empty values\n"," record_descriptions.append(f\"{col}: {row[col]}\")\n"," if record_descriptions:\n"," all_descriptions.append({\n"," 'id': row['id'],\n"," 'descriptions': '; '.join(record_descriptions) # Join descriptions with semicolon\n"," })\n","\n","# Convert to DataFrame for Parquet\n","output_df = pd.DataFrame(all_descriptions)\n","\n","# Save to Parquet file\n","output_path = '' #@param {type:'string'}\n","output_df.to_parquet(output_path, index=False)\n","os.remove(f'{file_path}')\n","print(f\"\\nDescriptions have been saved to '{output_path}'.\")"],"metadata":{"id":"-NXBRSv4jsUS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Build a dataset for training using a .jsonl file\n","\n","num_dataset_items = 200 #@param {type:'slider', max:10000}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import json\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import math\n","\n","# Step 3: Define the path to the JSONL file\n","file_path = '/kaggle/input/image-caption-dataset/e621_2022_02.jsonl' #@param {type:'string'}\n","\n","# Step 4: Read the JSONL file\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert to DataFrame\n","df = pd.DataFrame(data)\n","\n","# Step 5: Randomly select rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items * 1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = 
Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have enough valid images\n"," break\n"," url = row['url']\n"," # Combine description and tag_string for caption, ensuring no missing values\n"," description = row['description'] if pd.notnull(row['description']) else ''\n"," tag_string = row['tag_string'] if pd.notnull(row['tag_string']) else ''\n"," caption = f\"{description}, {tag_string}\".strip(', ')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly num_dataset_items if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","# Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk('/kaggle/output/custom_dataset')"],"metadata":{"id":"aAfdBkw_fNv0"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"WcBCR1eBfXWs"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Step 1: Mount Google Drive\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","#@markdown paste .parquet file stored on your Google Drive folder to see its characteristics\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Basic exploration of the Parquet file\n","print(\"First 5 rows of the dataset:\")\n","print(df.head())\n","\n","print(\"\\nDataset Info:\")\n","print(df.info())\n","\n","print(\"\\nBasic Statistics:\")\n","print(df.describe())\n","\n","print(\"\\nColumn Names:\")\n","print(df.columns.tolist())\n","\n","print(\"\\nMissing Values:\")\n","print(df.isnull().sum())\n","\n","# Optional: Display number of rows and columns\n","print(f\"\\nShape of the dataset: {df.shape}\")"],"metadata":{"id":"So-PKtbo5AVA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Read contents of a .parquet file\n","\n","# Import pandas\n","import pandas as pd\n","\n","# Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","parquet_column = 'descriptions' #@param {type:'string'}\n","# Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Set pandas display 
options to show full text without truncation\n","pd.set_option('display.max_colwidth', None) # Show full content of columns\n","pd.set_option('display.width', None) # Use full display width\n","\n","# Create sliders for selecting the range of captions\n","#@markdown Caption Range { run: \"auto\", display_mode: \"form\" }\n","start_at = 16814 #@param {type:\"slider\", min:0, max:33147, step:1}\n","caption_range = 247 #@param {type:'slider',min:1,max:1000,step:1}\n","start_index = start_at\n","end_index = start_at + caption_range\n","\n","include_either_words = '' #@param {type:'string', placeholder:'item1,item2...'}\n","#display_only = True #@param {type:'boolean'}\n","\n","_include_either_words = ''\n","for include_word in include_either_words.split(','):\n"," if include_word.strip()=='':continue\n"," _include_either_words = _include_either_words + include_word.strip().lower()+','+include_word.strip().title()+','\n","#-----#\n","_include_either_words = _include_either_words[:len(_include_either_words)-1]\n","\n","\n","# Ensure end_index is greater than start_index and within bounds\n","if end_index <= start_index:\n"," print(\"Error: End index must be greater than start index.\")\n","elif end_index > len(df):\n"," print(f\"Error: End index cannot exceed {len(df)}. Setting to maximum value.\")\n"," end_index = len(df)\n","elif start_index < 0:\n"," print(\"Error: Start index cannot be negative. Setting to 0.\")\n"," start_index = 0\n","\n","# Display the selected range of captions\n","filter_words = [w.strip() for w in _include_either_words.split(',') if w.strip() != '']\n","\n","categories= ['regular_summary:',';midjourney_style_summary:', 'individual_parts:']\n","\n","print(f\"\\nDisplaying captions from index {start_index} to {end_index-1}:\")\n","for index, caption in df[f'{parquet_column}'][start_index:end_index].items():\n"," # Skip captions that contain none of the filter words (show everything if no filter is set)\n"," if filter_words and not any(word in caption for word in filter_words):\n","  continue\n"," tmp = caption + '\\n\\n'\n"," for category in categories:\n","  tmp = tmp.replace(f'{category}',f'\\n\\n{category}\\n')\n"," print(f'Index {index}: {tmp}')\n"],"metadata":{"id":"wDhyb8M_7pkD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 200 #@param {type:'slider',max:1000}\n","\n","output_name='dataset1'#@param {type:'string'}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import math\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '/content/drive/MyDrive/dataset1.parquet' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly over-sample rows (1.2x the target) to allow for image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 
1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have 200 valid images\n"," break\n"," url = row['url']\n"," caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly 200 if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","#Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'/content/drive/MyDrive/{output_name}')"],"metadata":{"id":"XZvpJ5zw0fzR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset_name=''#@param {type:'string'}\n","\n","if dataset_name.strip()=='':\n"," dataset_name='my_dataset'\n","\n","\n","dataset.save_to_disk(f'/content/drive/MyDrive/{dataset_name}')\n","\n","\n"],"metadata":{"id":"iTyxazlM1OAn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"sQmoYDLHUXxF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"jFnWBQHa142R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"AmLgPcrdRqCJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"X5HLZqjTRt7L"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["🔄 Change to T4 Runtime : Past this point you can train a LoRa on the Dataset , but you need to change the runtime to T4 for that first\n","\n","See original file at:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb"],"metadata":{"id":"0Kmf1OP6Se4Q"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"ESLqweKz4xM_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Test the merged dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# 
!pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = ''#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"xUA37h2APkWc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"4hCnrtv6R9B1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"MSetS3MCR2qJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9CBpiISFa6C"},"source":["To format the dataset, all vision fine-tuning tasks should follow this format (a user turn carrying the instruction and image, followed by the assistant's answer):\n","\n","```python\n","[\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\n"," \"role\": \"assistant\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": sample[\"text\"]},\n"," ],\n"," },\n","]\n","```"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oPXzJZzHEgXe"},"outputs":[],"source":["#@markdown Convert the merged dataset to the 'correct' format for training the Gemma LoRA model\n","\n","instruction = \"Describe this image.\" # <- Select the prompt for your use case here\n","\n","def convert_to_conversation(sample):\n"," conversation = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": sample[\"text\"]}]},\n"," ]\n"," return {\"messages\": conversation}\n","pass"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gFW2qXIr7Ezy"},"outputs":[],"source":["converted_dataset = [convert_to_conversation(sample) for sample in dataset]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gGFzmplrEy9I"},"outputs":[],"source":["converted_dataset[0]"]},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9vJOucOw1qc6"},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n"," !pip install unsloth\n","else:\n"," # Do this only in Colab notebooks! 
Otherwise use pip install unsloth\n"," !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n"," !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n"," !pip install --no-deps unsloth"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QmUBVEnvCDJv"},"outputs":[],"source":["from unsloth import FastVisionModel # FastLanguageModel for LLMs\n","import torch\n","\n","# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n","fourbit_models = [\n"," \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\", # Llama 3.2 vision support\n"," \"unsloth/Llama-3.2-11B-Vision-bnb-4bit\",\n"," \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\", # Can fit in a 80GB card!\n"," \"unsloth/Llama-3.2-90B-Vision-bnb-4bit\",\n","\n"," \"unsloth/Pixtral-12B-2409-bnb-4bit\", # Pixtral fits in 16GB!\n"," \"unsloth/Pixtral-12B-Base-2409-bnb-4bit\", # Pixtral base model\n","\n"," \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\", # Qwen2 VL support\n"," \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n"," \"unsloth/Qwen2-VL-72B-Instruct-bnb-4bit\",\n","\n"," \"unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit\", # Any Llava variant works!\n"," \"unsloth/llava-1.5-7b-hf-bnb-4bit\",\n","] # More models at https://huggingface.co/unsloth\n","\n","model, processor = FastVisionModel.from_pretrained(\n"," \"unsloth/gemma-3-4b-pt\",\n"," load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.\n"," use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for long context\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bEzvL7Sm1CrS"},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n"," processor,\n"," \"gemma-3\"\n",")"]},{"cell_type":"markdown","metadata":{"id":"SXd9bTZd1aaL"},"source":["We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.\n","\n","**[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bZsfBuZDeCL"},"outputs":[],"source":["model = FastVisionModel.get_peft_model(\n"," model,\n"," finetune_vision_layers = True, # False if not finetuning vision layers\n"," finetune_language_layers = True, # False if not finetuning language layers\n"," finetune_attention_modules = True, # False if not finetuning attention layers\n"," finetune_mlp_modules = True, # False if not finetuning MLP layers\n","\n"," r = 16, # The larger, the higher the accuracy, but might overfit\n"," lora_alpha = 16, # Recommended alpha == r at least\n"," lora_dropout = 0,\n"," bias = \"none\",\n"," random_state = 3408,\n"," use_rslora = False, # We support rank stabilized LoRA\n"," loftq_config = None, # And LoftQ\n"," target_modules = \"all-linear\", # Optional now! Can specify a list if needed\n"," modules_to_save=[\n"," \"lm_head\",\n"," \"embed_tokens\",\n"," ],\n",")"]},{"cell_type":"markdown","metadata":{"id":"FecKS-dA82f5"},"source":["Before fine-tuning, let us evaluate the base model's performance. 
We do not expect strong results, as it has not encountered this chat template before."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vcat4UxA81vr"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[2][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"idAEIeSQ3xdS"},"source":["\n","### Train the model\n","Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support TRL's `DPOTrainer`!\n","\n","We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"95_Nn-89DhsL"},"outputs":[],"source":["from unsloth.trainer import UnslothVisionDataCollator\n","from trl import SFTTrainer, SFTConfig\n","\n","FastVisionModel.for_training(model) # Enable for training!\n","\n","trainer = SFTTrainer(\n"," model=model,\n"," train_dataset=converted_dataset,\n"," processing_class=processor.tokenizer,\n"," data_collator=UnslothVisionDataCollator(model, processor),\n"," args = SFTConfig(\n"," per_device_train_batch_size = 1,\n"," gradient_accumulation_steps = 4,\n"," gradient_checkpointing = True,\n","\n"," # use reentrant checkpointing\n"," gradient_checkpointing_kwargs = {\"use_reentrant\": False},\n"," max_grad_norm = 0.3, # max gradient norm based on QLoRA paper\n"," warmup_ratio = 0.03,\n"," #max_steps = 30,\n"," num_train_epochs = 5, # Set this instead of max_steps for full training runs\n"," learning_rate = 2e-4,\n"," logging_steps = 1,\n"," save_strategy=\"steps\",\n"," optim = \"adamw_torch_fused\",\n"," weight_decay = 0.01,\n"," lr_scheduler_type = \"cosine\",\n"," seed = 3407,\n"," output_dir = \"outputs\",\n"," report_to = \"none\", # For Weights and Biases\n","\n"," # You MUST put the below items for vision finetuning:\n"," remove_unused_columns = False,\n"," dataset_text_field = \"\",\n"," dataset_kwargs = {\"skip_prepare_dataset\": True},\n"," max_length = 2048,\n"," )\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"2ejIt2xSNKKp"},"outputs":[],"source":["# @title Show current memory stats\n","gpu_stats = torch.cuda.get_device_properties(0)\n","start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n","print(f\"GPU = {gpu_stats.name}. 
Max memory = {max_memory} GB.\")\n","print(f\"{start_gpu_memory} GB of memory reserved.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yqxqAZ7KJ4oL"},"outputs":[],"source":["trainer_stats = trainer.train()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"pCqnaKmlO1U9"},"outputs":[],"source":["# @title Show final memory and time stats\n","used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n","used_percentage = round(used_memory / max_memory * 100, 3)\n","lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n","print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n","print(\n"," f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",")\n","print(f\"Peak reserved memory = {used_memory} GB.\")\n","print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n","print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n","print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"]},{"cell_type":"markdown","metadata":{"id":"ekOmTR1hSNcr"},"source":["\n","### Inference\n","Let's run the model! You can modify the instruction and input—just leave the output blank.\n","\n","We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kR3gIAX-SM2q"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[10][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"code","source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files # For file upload in Colab\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload() # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n"," raise ValueError(\"No file uploaded. 
Please upload an image.\")\n","\n","# Get the first uploaded file\n","file_name = list(uploaded.keys())[0]\n","try:\n"," image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n","except Exception as e:\n"," raise ValueError(f\"Error loading image: {e}\")\n","\n","# Step 5: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 6: Prepare messages for the model\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","# Step 7: Apply chat template and prepare inputs\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","# Step 8: Generate output with text streaming\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(\n"," **inputs,\n"," streamer=text_streamer,\n"," max_new_tokens=512,\n"," use_cache=True,\n"," temperature=1.0,\n"," top_p=0.95,\n"," top_k=64\n",")"],"metadata":{"id":"oOyy5FUh8fBi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uMuVrWbjAzhc"},"source":["\n","### Saving, loading finetuned models\n","To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.\n","\n","**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"upcOlWe7A1vc"},"outputs":[],"source":["model.save_pretrained(\"lora_model\") # Local saving\n","processor.save_pretrained(\"lora_model\")\n","# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n","# processor.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"]},{"cell_type":"markdown","metadata":{"id":"AEEcJ4qfC7Lp"},"source":["Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MKX_XKs_BNZR"},"outputs":[],"source":["if False:\n"," from unsloth import FastVisionModel\n","\n"," model, processor = FastVisionModel.from_pretrained(\n"," model_name=\"lora_model\", # YOUR MODEL YOU USED FOR TRAINING\n"," load_in_4bit=True, # Set to False for 16bit LoRA\n"," )\n"," FastVisionModel.for_inference(model) # Enable for inference!\n","\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","sample = dataset[1]\n","image = sample[\"image\"].convert(\"RGB\")\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\n"," \"type\": \"text\",\n"," \"text\": sample[\"text\"],\n"," },\n"," {\n"," \"type\": \"image\",\n"," },\n"," ],\n"," },\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)\n","_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"f422JgM9sdVT"},"source":["### Saving to float16 for VLLM\n","\n","We also support saving to `float16` directly. 
Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."]}]}
\ No newline at end of file