Upload train_on_parquet.ipynb
Browse files- train_on_parquet.ipynb +1 -0
train_on_parquet.ipynb
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/parquet_explorer.ipynb","timestamp":1754497857381},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754475181338},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754312448728},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754310418707},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1754223895158},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1747490904984},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1740037333374},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1736477078136},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1725365086834}]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["Download a parquet file to your Google drive and load it from there into this notebook.\n","\n","Parquet files: https://huggingface.co/datasets/codeShare/chroma_prompts/tree/main\n","\n","E621 JSON files: https://huggingface.co/datasets/lodestones/e621-captions/tree/main"],"metadata":{"id":"LeCfcqgiQvCP"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"HFy5aDxM3G7O"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KYv7Y2gNPW_i"},"outputs":[],"source":["#@markdown Investigate a json file\n","\n","import json\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Initialize lists to store data\n","data = []\n","\n","# Read the .jsonl file line by line\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," # Parse each line as a JSON object\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert the list of JSON objects to a Pandas DataFrame for easier exploration\n","df = pd.DataFrame(data)\n","\n","# Display basic information about the DataFrame\n","print(\"=== File Overview ===\")\n","print(f\"Number of records: {len(df)}\")\n","print(\"\\nColumn names:\")\n","print(df.columns.tolist())\n","print(\"\\nData types:\")\n","print(df.dtypes)\n","\n","# Display the first few rows\n","print(\"\\n=== First 5 Rows ===\")\n","print(df.head())\n","\n","# Display basic statistics\n","print(\"\\n=== Basic Statistics ===\")\n","print(df.describe(include='all'))\n","\n","# Check for missing values\n","print(\"\\n=== Missing Values ===\")\n","print(df.isnull().sum())\n","\n","# Optional: Display unique values in each column\n","print(\"\\n=== Unique Values per Column ===\")\n","for col in df.columns:\n"," print(f\"{col}: {df[col].nunique()} unique values\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dnIWOPPqSTnw"},"outputs":[],"source":["#@markdown Investigate a json file pt 2\n","\n","import json\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from collections import Counter\n","import numpy as np\n","\n","# Set up plotting style\n","plt.style.use('default')\n","%matplotlib inline\n","\n","# Path to the uploaded .jsonl file\n","#file_path = ''\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# 1. Rating Distribution\n","print(\"=== Rating Distribution ===\")\n","rating_counts = df['rating'].value_counts()\n","plt.figure(figsize=(8, 5))\n","sns.barplot(x=rating_counts.index, y=rating_counts.values)\n","plt.title('Distribution of Image Ratings')\n","plt.xlabel('Rating')\n","plt.ylabel('Count')\n","plt.show()\n","print(rating_counts)\n","\n","# 2. Tag Analysis\n","print(\"\\n=== Top 10 Most Common Tags ===\")\n","# Combine all tags into a single list\n","all_tags = []\n","for tags in df['tag_string'].dropna():\n"," all_tags.extend(tags.split())\n","tag_counts = Counter(all_tags)\n","top_tags = pd.DataFrame(tag_counts.most_common(10), columns=['Tag', 'Count'])\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x='Count', y='Tag', data=top_tags)\n","plt.title('Top 10 Most Common Tags')\n","plt.show()\n","print(top_tags)\n","\n","# 3. Image Dimensions Analysis\n","print(\"\\n=== Image Dimensions Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'], df['image_height'], alpha=0.5, s=50)\n","plt.title('Image Width vs. Height')\n","plt.xlabel('Width (pixels)')\n","plt.ylabel('Height (pixels)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Median Width: {df['image_width'].median()}\")\n","print(f\"Median Height: {df['image_height'].median()}\")\n","print(f\"Aspect Ratio (Width/Height) Stats:\\n{df['image_width'].div(df['image_height']).describe()}\")\n","\n","# 4. Score and Voting Analysis\n","print(\"\\n=== Score and Voting Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['score'], bins=30, kde=True)\n","plt.title('Distribution of Image Scores')\n","plt.xlabel('Score')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Score Stats:\\n{df['score'].describe()}\")\n","print(f\"\\nCorrelation between Up Score and Down Score: {df['up_score'].corr(df['down_score'])}\")\n","\n","# 5. Summary Length Analysis\n","print(\"\\n=== Summary Length Analysis ===\")\n","df['summary_length'] = df['regular_summary'].dropna().apply(lambda x: len(str(x).split()))\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['summary_length'], bins=30, kde=True)\n","plt.title('Distribution of Regular Summary Word Counts')\n","plt.xlabel('Word Count')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Summary Length Stats:\\n{df['summary_length'].describe()}\")\n","\n","# 6. Missing Data Heatmap\n","print(\"\\n=== Missing Data Heatmap ===\")\n","plt.figure(figsize=(12, 8))\n","sns.heatmap(df.isnull(), cbar=False, cmap='viridis')\n","plt.title('Missing Data Heatmap')\n","plt.show()\n","\n","# 7. Source Platform Analysis\n","print(\"\\n=== Source Platform Analysis ===\")\n","# Extract domain from source URLs\n","df['source_domain'] = df['source'].dropna().str.extract(r'(https?://[^/]+)')\n","source_counts = df['source_domain'].value_counts().head(10)\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x=source_counts.values, y=source_counts.index)\n","plt.title('Top 10 Source Domains')\n","plt.xlabel('Count')\n","plt.ylabel('Domain')\n","plt.show()\n","print(source_counts)\n","\n","# 8. File Size vs. Image Dimensions\n","print(\"\\n=== File Size vs. Image Dimensions ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'] * df['image_height'], df['file_size'], alpha=0.5)\n","plt.title('File Size vs. Image Area')\n","plt.xlabel('Image Area (Width * Height)')\n","plt.ylabel('File Size (bytes)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Correlation between Image Area and File Size: {df['file_size'].corr(df['image_width'] * df['image_height'])}\")"]},{"cell_type":"code","source":["#@markdown convert E621 JSON to parquet file\n","\n","import json,os\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# Define columns that likely contain prompts/image descriptions\n","description_columns = [\n"," 'regular_summary',\n"," 'individual_parts',\n"," 'midjourney_style_summary',\n"," 'deviantart_commission_request',\n"," 'brief_summary'\n","]\n","\n","# Initialize a list to store all descriptions\n","all_descriptions = []\n","\n","# Iterate through each row and collect non-empty descriptions\n","for index, row in df.iterrows():\n"," record_descriptions = []\n"," for col in description_columns:\n"," if pd.notnull(row[col]) and row[col]: # Check for non-null and non-empty values\n"," record_descriptions.append(f\"{col}: {row[col]}\")\n"," if record_descriptions:\n"," all_descriptions.append({\n"," 'id': row['id'],\n"," 'descriptions': '; '.join(record_descriptions) # Join descriptions with semicolon\n"," })\n","\n","# Convert to DataFrame for Parquet\n","output_df = pd.DataFrame(all_descriptions)\n","\n","# Save to Parquet file\n","output_path = '' #@param {type:'string'}\n","output_df.to_parquet(output_path, index=False)\n","os.remove(f'{file_path}')\n","print(f\"\\nDescriptions have been saved to '{output_path}'.\")"],"metadata":{"id":"-NXBRSv4jsUS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Step 1: Mount Google Drive\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","#@markdown paste .parquet file stored on your Google Drive folder to see its characteristics\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Basic exploration of the Parquet file\n","print(\"First 5 rows of the dataset:\")\n","print(df.head())\n","\n","print(\"\\nDataset Info:\")\n","print(df.info())\n","\n","print(\"\\nBasic Statistics:\")\n","print(df.describe())\n","\n","print(\"\\nColumn Names:\")\n","print(df.columns.tolist())\n","\n","print(\"\\nMissing Values:\")\n","print(df.isnull().sum())\n","\n","# Optional: Display number of rows and columns\n","print(f\"\\nShape of the dataset: {df.shape}\")"],"metadata":{"id":"So-PKtbo5AVA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Read contents of a .parquet file\n","\n","# Import pandas\n","import pandas as pd\n","\n","# Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","parquet_column = 'descriptions' #@param {type:'string'}\n","# Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Set pandas display options to show full text without truncation\n","pd.set_option('display.max_colwidth', None) # Show full content of columns\n","pd.set_option('display.width', None) # Use full display width\n","\n","# Create sliders for selecting the range of captions\n","#@markdown Caption Range { run: \"auto\", display_mode: \"form\" }\n","start_at = 16814 #@param {type:\"slider\", min:0, max:33147, step:1}\n","range = 247 #@param {type:'slider',min:1,max:1000,step:1}\n","start_index = start_at\n","end_index = start_at + range\n","###@param {type:\"slider\", min:1, max:33148, step:1}\n","\n","include_either_words = '' #@param {type:'string', placeholder:'item1,item2...'}\n","#display_only = True #@param {type:'boolean'}\n","\n","_include_either_words = ''\n","for include_word in include_either_words.split(','):\n"," if include_word.strip()=='':continue\n"," _include_either_words= include_either_words + include_word.lower()+','+include_word.title() +','\n","#-----#\n","_include_either_words = _include_either_words[:len(_include_either_words)-1]\n","\n","\n","# Ensure end_index is greater than start_index and within bounds\n","if end_index <= start_index:\n"," print(\"Error: End index must be greater than start index.\")\n","elif end_index > len(df):\n"," print(f\"Error: End index cannot exceed {len(df)}. Setting to maximum value.\")\n"," end_index = len(df)\n","elif start_index < 0:\n"," print(\"Error: Start index cannot be negative. Setting to 0.\")\n"," start_index = 0\n","\n","# Display the selected range of captions\n","tmp =''\n","\n","categories= ['regular_summary:',';midjourney_style_summary:', 'individual_parts:']\n","\n","print(f\"\\nDisplaying captions from index {start_index} to {end_index-1}:\")\n","for index, caption in df[f'{parquet_column}'][start_index:end_index].items():\n"," for include_word in _include_either_words.split(','):\n"," found = True\n"," if (include_word.strip() in caption) or include_word.strip()=='':\n"," #----#\n"," if not found: continue\n"," tmp= caption + '\\n\\n'\n"," for category in categories:\n"," tmp = tmp.replace(f'{category}',f'\\n\\n{category}\\n')\n"," #----#\n"," print(f'Index {index}: {tmp}')\n"],"metadata":{"id":"wDhyb8M_7pkD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 200 #@param {type:'slider',max:1000}\n","\n","\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly select 300 rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have 200 valid images\n"," break\n"," url = row['url']\n"," caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly 200 if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","# Optional: Save the dataset to disk (if needed)\n","#dataset.save_to_disk('/content/drive/MyDrive/Chroma prompts/custom_dataset')"],"metadata":{"id":"XZvpJ5zw0fzR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset_name=''#@param {type:'string'}\n","\n","if dataset_name.strip()=='':\n"," dataset_name='my_dataset'\n","\n","\n","dataset.save_to_disk(f'/content/drive/MyDrive/{dataset_name}')\n","\n","\n"],"metadata":{"id":"iTyxazlM1OAn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"sQmoYDLHUXxF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"jFnWBQHa142R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown load two .parquet datasets for merging\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset1_path = '' #@param {type: 'string'}\n","\n","dataset2_path = '' #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset1)\n","print(dataset2)\n","\n","# Step 7: Example of accessing an image and text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", redcaps_dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"LoCcBJqs4pzL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"AmLgPcrdRqCJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"X5HLZqjTRt7L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Merge the two .parquet files into one\n","\n","# Step 1: Import required libraries\n","from datasets import load_from_disk, concatenate_datasets\n","from google.colab import drive\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","drive.mount('/content/drive')\n","\n","# Step 3: Define paths for the datasets\n","dataset1_path = '' #@param {type:'string'}\n","dataset2_path = '' #@param {type:'string'}\n","merged_dataset_path = '' #@param {type:'string'}\n","\n","# Step 4: Load the datasets\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Datasets loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading datasets: {e}\")\n"," raise\n","\n","# Step 5: Verify the datasets\n","print(\"Dataset 1:\", dataset1)\n","print(\"Dataset 2:\", dataset2)\n","\n","# Step 6: Merge the datasets\n","try:\n"," merged_dataset = concatenate_datasets([dataset1, dataset2])\n"," print(\"Datasets merged successfully!\")\n","except Exception as e:\n"," print(f\"Error merging datasets: {e}\")\n"," raise\n","\n","# Step 7: Verify the merged dataset\n","print(\"Merged Dataset:\", merged_dataset)\n","\n","# Step 8: Save the merged dataset to Google Drive\n","try:\n"," merged_dataset.save_to_disk(merged_dataset_path)\n"," print(f\"Merged dataset saved successfully to {merged_dataset_path}\")\n","except Exception as e:\n"," print(f\"Error saving merged dataset: {e}\")\n"," raise\n","\n","# Step 9: Optional - Verify the saved dataset by loading it back\n","try:\n"," loaded_merged_dataset = load_from_disk(merged_dataset_path)\n"," print(\"Saved merged dataset loaded successfully for verification:\")\n"," print(loaded_merged_dataset)\n","except Exception as e:\n"," print(f\"Error loading saved merged dataset: {e}\")\n"," raise"],"metadata":{"id":"HF_cmJu1EMJV"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["🔄 Change to T4 Runtime : Past this point you can train a LoRa on the Dataset , but you need to change the runtime to T4 for that first\n","\n","See original file at:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb"],"metadata":{"id":"0Kmf1OP6Se4Q"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"ESLqweKz4xM_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Test the merged dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = ''#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"xUA37h2APkWc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"4hCnrtv6R9B1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"MSetS3MCR2qJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9CBpiISFa6C"},"source":["To format the dataset, all vision fine-tuning tasks should follow this format:\n","\n","```python\n","[\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n","]\n","```"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oPXzJZzHEgXe"},"outputs":[],"source":["#@markdown Convert the merged dataset to the 'correct' format for training the Gemma LoRa model\n","\n","instruction = \"Describe this image.\" # <- Select the prompt for your use case here\n","\n","def convert_to_conversation(sample):\n"," conversation = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": sample[\"text\"]}]},\n"," ]\n"," return {\"messages\": conversation}\n","pass"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gFW2qXIr7Ezy"},"outputs":[],"source":["converted_dataset = [convert_to_conversation(sample) for sample in dataset]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gGFzmplrEy9I"},"outputs":[],"source":["converted_dataset[0]"]},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9vJOucOw1qc6"},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n"," !pip install unsloth\n","else:\n"," # Do this only in Colab notebooks! Otherwise use pip install unsloth\n"," !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n"," !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n"," !pip install --no-deps unsloth"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QmUBVEnvCDJv"},"outputs":[],"source":["from unsloth import FastVisionModel # FastLanguageModel for LLMs\n","import torch\n","\n","# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n","fourbit_models = [\n"," \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\", # Llama 3.2 vision support\n"," \"unsloth/Llama-3.2-11B-Vision-bnb-4bit\",\n"," \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\", # Can fit in a 80GB card!\n"," \"unsloth/Llama-3.2-90B-Vision-bnb-4bit\",\n","\n"," \"unsloth/Pixtral-12B-2409-bnb-4bit\", # Pixtral fits in 16GB!\n"," \"unsloth/Pixtral-12B-Base-2409-bnb-4bit\", # Pixtral base model\n","\n"," \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\", # Qwen2 VL support\n"," \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n"," \"unsloth/Qwen2-VL-72B-Instruct-bnb-4bit\",\n","\n"," \"unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit\", # Any Llava variant works!\n"," \"unsloth/llava-1.5-7b-hf-bnb-4bit\",\n","] # More models at https://huggingface.co/unsloth\n","\n","model, processor = FastVisionModel.from_pretrained(\n"," \"unsloth/gemma-3-4b-pt\",\n"," load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.\n"," use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for long context\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bEzvL7Sm1CrS"},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n"," processor,\n"," \"gemma-3\"\n",")"]},{"cell_type":"markdown","metadata":{"id":"SXd9bTZd1aaL"},"source":["We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.\n","\n","**[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bZsfBuZDeCL"},"outputs":[],"source":["model = FastVisionModel.get_peft_model(\n"," model,\n"," finetune_vision_layers = True, # False if not finetuning vision layers\n"," finetune_language_layers = True, # False if not finetuning language layers\n"," finetune_attention_modules = True, # False if not finetuning attention layers\n"," finetune_mlp_modules = True, # False if not finetuning MLP layers\n","\n"," r = 16, # The larger, the higher the accuracy, but might overfit\n"," lora_alpha = 16, # Recommended alpha == r at least\n"," lora_dropout = 0,\n"," bias = \"none\",\n"," random_state = 3408,\n"," use_rslora = False, # We support rank stabilized LoRA\n"," loftq_config = None, # And LoftQ\n"," target_modules = \"all-linear\", # Optional now! Can specify a list if needed\n"," modules_to_save=[\n"," \"lm_head\",\n"," \"embed_tokens\",\n"," ],\n",")"]},{"cell_type":"markdown","metadata":{"id":"FecKS-dA82f5"},"source":["Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vcat4UxA81vr"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[2][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"idAEIeSQ3xdS"},"source":["<a name=\"Train\"></a>\n","### Train the model\n","Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support TRL's `DPOTrainer`!\n","\n","We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"95_Nn-89DhsL"},"outputs":[],"source":["from unsloth.trainer import UnslothVisionDataCollator\n","from trl import SFTTrainer, SFTConfig\n","\n","FastVisionModel.for_training(model) # Enable for training!\n","\n","trainer = SFTTrainer(\n"," model=model,\n"," train_dataset=converted_dataset,\n"," processing_class=processor.tokenizer,\n"," data_collator=UnslothVisionDataCollator(model, processor),\n"," args = SFTConfig(\n"," per_device_train_batch_size = 1,\n"," gradient_accumulation_steps = 4,\n"," gradient_checkpointing = True,\n","\n"," # use reentrant checkpointing\n"," gradient_checkpointing_kwargs = {\"use_reentrant\": False},\n"," max_grad_norm = 0.3, # max gradient norm based on QLoRA paper\n"," warmup_ratio = 0.03,\n"," #max_steps = 30,\n"," num_train_epochs = 5, # Set this instead of max_steps for full training runs\n"," learning_rate = 2e-4,\n"," logging_steps = 1,\n"," save_strategy=\"steps\",\n"," optim = \"adamw_torch_fused\",\n"," weight_decay = 0.01,\n"," lr_scheduler_type = \"cosine\",\n"," seed = 3407,\n"," output_dir = \"outputs\",\n"," report_to = \"none\", # For Weights and Biases\n","\n"," # You MUST put the below items for vision finetuning:\n"," remove_unused_columns = False,\n"," dataset_text_field = \"\",\n"," dataset_kwargs = {\"skip_prepare_dataset\": True},\n"," max_length = 2048,\n"," )\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"2ejIt2xSNKKp"},"outputs":[],"source":["# @title Show current memory stats\n","gpu_stats = torch.cuda.get_device_properties(0)\n","start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n","print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n","print(f\"{start_gpu_memory} GB of memory reserved.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yqxqAZ7KJ4oL"},"outputs":[],"source":["trainer_stats = trainer.train()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"pCqnaKmlO1U9"},"outputs":[],"source":["# @title Show final memory and time stats\n","used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n","used_percentage = round(used_memory / max_memory * 100, 3)\n","lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n","print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n","print(\n"," f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",")\n","print(f\"Peak reserved memory = {used_memory} GB.\")\n","print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n","print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n","print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"]},{"cell_type":"markdown","metadata":{"id":"ekOmTR1hSNcr"},"source":["<a name=\"Inference\"></a>\n","### Inference\n","Let's run the model! You can modify the instruction and input—just leave the output blank.\n","\n","We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kR3gIAX-SM2q"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[10][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"code","source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files # For file upload in Colab\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload() # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n"," raise ValueError(\"No file uploaded. Please upload an image.\")\n","\n","# Get the first uploaded file\n","file_name = list(uploaded.keys())[0]\n","try:\n"," image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n","except Exception as e:\n"," raise ValueError(f\"Error loading image: {e}\")\n","\n","# Step 5: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 6: Prepare messages for the model\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","# Step 7: Apply chat template and prepare inputs\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","# Step 8: Generate output with text streaming\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(\n"," **inputs,\n"," streamer=text_streamer,\n"," max_new_tokens=512,\n"," use_cache=True,\n"," temperature=1.0,\n"," top_p=0.95,\n"," top_k=64\n",")"],"metadata":{"id":"oOyy5FUh8fBi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uMuVrWbjAzhc"},"source":["<a name=\"Save\"></a>\n","### Saving, loading finetuned models\n","To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.\n","\n","**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"upcOlWe7A1vc"},"outputs":[],"source":["model.save_pretrained(\"lora_model\") # Local saving\n","processor.save_pretrained(\"lora_model\")\n","# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n","# processor.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"]},{"cell_type":"markdown","metadata":{"id":"AEEcJ4qfC7Lp"},"source":["Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MKX_XKs_BNZR"},"outputs":[],"source":["if False:\n"," from unsloth import FastVisionModel\n","\n"," model, processor = FastVisionModel.from_pretrained(\n"," model_name=\"lora_model\", # YOUR MODEL YOU USED FOR TRAINING\n"," load_in_4bit=True, # Set to False for 16bit LoRA\n"," )\n"," FastVisionModel.for_inference(model) # Enable for inference!\n","\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","sample = dataset[1]\n","image = sample[\"image\"].convert(\"RGB\")\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\n"," \"type\": \"text\",\n"," \"text\": sample[\"text\"],\n"," },\n"," {\n"," \"type\": \"image\",\n"," },\n"," ],\n"," },\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)\n","_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"f422JgM9sdVT"},"source":["### Saving to float16 for VLLM\n","\n","We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."]}]}
|