codeShare committed · verified
Commit 2c4489d · 1 Parent(s): dbfbb6f

Upload gemma_image_captioner.ipynb

Files changed (1)
  1. gemma_image_captioner.ipynb (+1, −0)
gemma_image_captioner.ipynb ADDED
+ {"cells":[{"cell_type":"markdown","source":["This notebook creates captions from images using a lora adaptation of google gemma 3 LLM. This lora is very basic as it has only been trained in 400 images of reddit posts and e621 NSFW posts over 5 epochs.\n","\n","Created by Adcom: https://tensor.art/u/754389913230900026"],"metadata":{"id":"HbBHYqQY8iHH"}},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":2,"metadata":{"id":"9vJOucOw1qc6","executionInfo":{"status":"ok","timestamp":1754490549383,"user_tz":-120,"elapsed":33363,"user":{"displayName":"fukU Google","userId":"02763165356193834046"}}},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n"," !pip install unsloth\n","else:\n"," # Do this only in Colab notebooks! Otherwise use pip install unsloth\n"," !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n"," !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n"," !pip install --no-deps unsloth"]},{"cell_type":"code","source":["if True:\n"," from unsloth import FastVisionModel\n","\n"," model, processor = FastVisionModel.from_pretrained(\n"," model_name='codeShare/flux_chroma_image_captioner', # YOUR MODEL YOU USED FOR TRAINING\n"," load_in_4bit=True, # Set to False for 16bit LoRA\n"," )\n"," FastVisionModel.for_inference(model) # Enable for inference!"],"metadata":{"id":"9yu3CI6SsjN7"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":5,"metadata":{"id":"bEzvL7Sm1CrS","executionInfo":{"status":"ok","timestamp":1754490787703,"user_tz":-120,"elapsed":9,"user":{"displayName":"fukU Google","userId":"02763165356193834046"}}},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n"," processor,\n"," \"gemma-3\"\n",")"]},{"cell_type":"markdown","source":["A prompt to upload an image for processing will appear when running this cell"],"metadata":{"id":"DmbcTDgq8Bjg"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"oOyy5FUh8fBi"},"outputs":[],"source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files # For file upload in Colab\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload() # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n"," raise ValueError(\"No file uploaded. 
### Caption a folder of images

Upload a set of images to /content/ (via the file browser on the left) prior to running this cell. You can also unzip a .zip archive and rename the folder containing the images to '/content/input'.

```python
# Step 1: Import required libraries
from PIL import Image
import torch
import os
from pathlib import Path

# Step 2: Model and processor are already loaded and configured above
FastVisionModel.for_inference(model)  # enable inference mode

# Step 3: Define input and output directories
input_dirs = ['/content/', '/content/input/']
output_dir = '/content/output/'

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Step 4: Define supported image extensions
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}

# Step 5: Collect all image files from the input directories
image_files = []
for input_dir in input_dirs:
    if os.path.exists(input_dir):
        for file in Path(input_dir).rglob('*'):
            if file.suffix.lower() in image_extensions:
                image_files.append(file)
    else:
        print(f"Directory {input_dir} does not exist, skipping...")

# Deduplicate: rglob from /content/ also visits /content/input/
image_files = sorted(set(image_files))

if not image_files:
    raise ValueError("No images found in /content/ or /content/input/")

# Step 6: Define the instruction
instruction = "Describe this image."

# Step 7: Process each image
for image_path in image_files:
    try:
        # Load the image
        image = Image.open(image_path).convert('RGB')

        # Prepare messages for the model
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"}, {"type": "text", "text": instruction}],
            }
        ]

        # Apply the chat template and prepare inputs
        input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")

        # Generate output without streaming
        print(f"\nProcessing {image_path.name}...")
        result = model.generate(
            **inputs,
            max_new_tokens=512,
            use_cache=True,
            temperature=1.0,
            top_p=0.95,
            top_k=64,
        )

        # Decode only the newly generated tokens (skip the echoed prompt)
        prompt_len = inputs["input_ids"].shape[1]
        caption = processor.decode(result[0][prompt_len:], skip_special_tokens=True).strip()

        # Print the caption with extra whitespace for easy selection
        print(f"\n=== Caption for {image_path.name} ===\n\n{caption}\n\n====================\n")

        # Save the image and caption
        output_image_path = os.path.join(output_dir, image_path.name)
        output_caption_path = os.path.join(output_dir, f"{image_path.stem}.txt")

        # Copy the original image to the output directory
        image.save(output_image_path)

        # Save the caption to a text file
        with open(output_caption_path, 'w') as f:
            f.write(caption)

        print(f"Saved image and caption for {image_path.name}")

        # Delete the original image if it's in /content/ (but not /content/input/)
        if str(image_path).startswith('/content/') and not str(image_path).startswith('/content/input/'):
            try:
                os.remove(image_path)
                print(f"Deleted original image: {image_path}")
            except Exception as e:
                print(f"Error deleting {image_path}: {e}")

    except Exception as e:
        print(f"Error processing {image_path.name}: {e}")

print(f"\nProcessing complete. Output saved to {output_dir}")
```

### Utilities

```python
# @markdown 💾 Create a .zip file of the output folder in /content/
output_filename = ''  # @param {type:'string'}
if output_filename.strip() == '':
    output_filename = 'chroma_prompts'
# -----#
import shutil
# make_archive appends '.zip' itself, so strip any user-supplied suffix
shutil.make_archive(f"/content/{output_filename.removesuffix('.zip')}", 'zip', '/content/output')
```
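Not part of the original notebook: if you want the archive on your local machine before clearing the workspace, Colab's files helper can trigger a browser download. A minimal sketch, assuming the default archive name from the cell above:

```python
# Download the archive created by the previous cell (default name assumed).
from google.colab import files
files.download('/content/chroma_prompts.zip')
```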
```python
# @markdown 🧹 Clear all images / .txt files / .zip files from /content/
import os
from pathlib import Path

# Directory to clean
directory_to_clean = '/content/'

# File extensions to delete
extensions_to_delete = {'.zip', '.webp', '.jpg', '.jpeg', '.png', '.bmp', '.gif', '.txt'}

# Delete files in the directory whose extensions match
for file in Path(directory_to_clean).iterdir():
    if file.suffix.lower() in extensions_to_delete:
        try:
            os.remove(file)
            print(f"Deleted: {file}")
        except Exception as e:
            print(f"Error deleting {file}: {e}")

print(f"\nCleaning of {directory_to_clean} complete.")
```
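The cell above only removes loose files directly under /content/; the input/ and output/ folders themselves are left in place. A small optional sketch (not in the original notebook) for removing them too, assuming the paths used earlier:

```python
# Remove the input/ and output/ folders created earlier, if present.
import shutil
for folder in ('/content/input', '/content/output'):
    shutil.rmtree(folder, ignore_errors=True)
    print(f"Removed {folder} (if it existed)")
```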