# Image_Prompting-and-Captioning / generate_caption.py
import os

import torch
from PIL import Image
import torchvision.transforms.functional as TVF
import google.generativeai as genai

# Configure the Gemini client from the environment. Fall back to None so callers get a
# clear error message instead of an unhandled exception when the key is missing.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
else:
    print("Warning: GOOGLE_API_KEY not found in environment variables")
    gemini_model = None
# Prompt templates keyed by caption type. Index 0 is used when no length is given,
# index 1 when a numeric word count is given, and index 2 for a named length ("short", "long", ...).
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Training Prompt": [
        "Write a stable diffusion prompt for this image.",
        "Write a stable diffusion prompt for this image within {word_count} words.",
        "Write a {length} stable diffusion prompt for this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
}
def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
    """Extract image features with a CLIP vision model, optionally passed through an adapter."""
    # Resize to the CLIP input resolution and normalize pixel values to [-1, 1].
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        if image_adapter is not None:
            embedded_images = image_adapter(vision_outputs.hidden_states)
            return embedded_images
        return vision_outputs.last_hidden_state
def generate_caption(input_image: Image.Image,
                     caption_type: str = "Descriptive",
                     caption_length: str = "long",
                     extra_options: list = None,
                     name_input: str = "",
                     custom_prompt: str = "",
                     clip_model=None,
                     image_adapter=None):
"""
Generate caption for an image using Gemini API. No Bullet points, proper punctuation, and no extra information.
Args:
input_image: PIL Image object
caption_type: Type of caption ("Descriptive", "Training Prompt", "MidJourney")
caption_length: Length specification ("any", "short", "long", etc. or number as string)
extra_options: List of extra options
name_input: Name to use for person/character in image
custom_prompt: Custom prompt to override default settings
clip_model: CLIP model (optional, for compatibility)
image_adapter: Image adapter model (optional, for compatibility)
Returns:
tuple: (generated_caption)
"""
    if gemini_model is None:
        return "Error: Gemini API key not configured", "Please set the GOOGLE_API_KEY environment variable"
    if input_image is None:
        return "Error: No image provided", "Please provide an image"
    if extra_options is None:
        extra_options = []

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Normalize the length argument: "any" -> None, numeric strings -> int,
    # anything else stays a named length such as "short" or "long".
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    # Pick the prompt template that matches the kind of length specification.
    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    elif isinstance(length, str):
        map_idx = 2
    else:
        raise ValueError(f"Invalid caption length: {length}")

    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
    if len(extra_options) > 0:
        prompt_str += " " + " ".join(extra_options)
    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    # A non-empty custom prompt overrides everything built above.
    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()

    try:
        if clip_model is not None:
            image_features = get_image_features(input_image, clip_model, image_adapter)
            print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")
full_prompt = f"""You are a helpful image captioner.
{prompt_str}
Please analyze the provided image and generate a caption according to the instructions above. Just only the caption text, no additional information."""
response = gemini_model.generate_content([full_prompt, input_image])
if response.text:
caption = response.text.strip()
else:
caption = "Failed to generate caption"
except Exception as e:
print(f"Error generating caption: {str(e)}")
return prompt_str, f"Error: {str(e)}"
return prompt_str, caption
def caption_image_from_path(image_path: str, **kwargs):
    """Caption an image from a file path."""
    image = Image.open(image_path)
    return generate_caption(image, **kwargs)


def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
    """Simple interface to caption an image."""
    image = Image.open(image_path)
    prompt_used, caption = generate_caption(image, caption_type=caption_type)
    print(f"Caption: {caption}")
    return caption
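

# Minimal usage sketch, not part of the original module's API surface: it assumes
# GOOGLE_API_KEY is set in the environment (so `gemini_model` is configured) and uses a
# hypothetical image path that you would replace with a real file.
if __name__ == "__main__":
    example_path = "example.jpg"  # hypothetical placeholder path
    if gemini_model is not None and os.path.exists(example_path):
        prompt_used, caption = caption_image_from_path(example_path, caption_type="Descriptive")
        print(f"Prompt used: {prompt_used}")
        print(f"Caption: {caption}")
    else:
        print("Skipping demo: set GOOGLE_API_KEY and point example_path at an existing image.")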