File size: 5,334 Bytes
19dc712 4aa1ed7 19dc712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import torch
from PIL import Image
import torchvision.transforms.functional as TVF
import google.generativeai as genai
import os
# The key is read from GOOGLE_API_KEY; without it the module still imports,
# but gemini_model stays None and a warning is printed.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
gemini_model = None
if not GEMINI_API_KEY:
    print("Warning: GOOGLE_API_KEY not found in environment variables")
else:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
# Prompt templates keyed by caption type. Each entry holds three variants,
# in this order:
#   [0] no length constraint
#   [1] word-count constraint  (uses the {word_count} placeholder)
#   [2] length-adjective form  (uses the {length} placeholder)
# The placeholders are left unformatted here; they are filled in later with
# str.format.
CAPTION_TYPE_MAP = {
    caption_kind: [
        f"Write a {subject}.",
        f"Write a {subject} within {{word_count}} words.",
        f"Write a {{length}} {subject}.",
    ]
    for caption_kind, subject in (
        ("Descriptive", "descriptive caption for this image in a formal tone"),
        ("Training Prompt", "stable diffusion prompt for this image"),
        ("MidJourney", "MidJourney prompt for this image"),
    )
}
def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
    """Run an image through a vision model and return its features.

    The image is converted to RGB, resized to 384x384 with Lanczos
    resampling, scaled to [0, 1] and then normalized with mean 0.5 / std 0.5
    before being passed to ``clip_model``.

    Args:
        input_image: PIL image to encode.
        clip_model: Callable accepting ``pixel_values=`` and
            ``output_hidden_states=True``; its result must expose
            ``hidden_states`` and ``last_hidden_state``.
        image_adapter: Optional callable applied to the model's
            ``hidden_states``. When given, its output is returned instead of
            ``last_hidden_state``.

    Returns:
        The adapter output if ``image_adapter`` is provided, otherwise the
        model's ``last_hidden_state``.
    """
    # Force 3-channel RGB so grayscale, palette and alpha images all yield a
    # (1, 3, 384, 384) tensor rather than a mode-dependent channel count.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])

    # Keep the input on the same device as the model when the model reports
    # one (Hugging Face models expose ``.device``); otherwise leave it on CPU.
    device = getattr(clip_model, "device", None)
    if isinstance(device, (torch.device, str)):
        pixel_values = pixel_values.to(device)

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        if image_adapter is not None:
            return image_adapter(vision_outputs.hidden_states)
    return vision_outputs.last_hidden_state
def _select_prompt_template(caption_type, caption_length):
    """Return the CAPTION_TYPE_MAP template matching the length style."""
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            # Non-numeric strings ("short", "long", ...) stay as adjectives.
            pass

    if length is None:
        map_idx = 0  # no length constraint
    elif isinstance(length, int):
        map_idx = 1  # "{word_count} words" variant
    elif isinstance(length, str):
        map_idx = 2  # "{length}" adjective variant
    else:
        raise ValueError(f"Invalid caption length: {length}")
    return CAPTION_TYPE_MAP[caption_type][map_idx]


def generate_caption(input_image: Image.Image,
                     caption_type: str = "Descriptive",
                     caption_length: str = "long",
                     extra_options: list = None,
                     name_input: str = "",
                     custom_prompt: str = "",
                     clip_model=None,
                     image_adapter=None):
    """Generate a caption for an image using the Gemini API.

    The instruction sent to Gemini is built from ``CAPTION_TYPE_MAP`` (chosen
    by ``caption_type`` and ``caption_length``), followed by any
    ``extra_options``. A non-empty ``custom_prompt`` replaces that instruction
    entirely.

    Args:
        input_image: PIL Image object.
        caption_type: Key of ``CAPTION_TYPE_MAP`` ("Descriptive",
            "Training Prompt", "MidJourney").
        caption_length: "any" for no constraint, a number (or numeric string)
            for a word limit, or an adjective such as "short" / "long".
        extra_options: Extra instruction sentences appended to the prompt.
            They may contain the ``{name}`` placeholder.
        name_input: Value substituted for ``{name}`` in the prompt.
        custom_prompt: If non-blank, used verbatim instead of the built prompt.
        clip_model: Optional vision model. When given, image features are
            extracted and their shape is printed; they are not sent to Gemini.
        image_adapter: Optional adapter passed through to
            ``get_image_features``.

    Returns:
        tuple: ``(prompt_used, caption)``. If the Gemini call fails the second
        element is ``"Error: <message>"``. If Gemini is not configured or no
        image is given, the tuple is ``(error_message, hint)`` instead.

    Raises:
        ValueError: If ``caption_length`` is neither a string, an int nor
            ``None``.
        KeyError: If ``caption_type`` is not a key of ``CAPTION_TYPE_MAP``.
    """
    if gemini_model is None:
        # The key is read from GOOGLE_API_KEY, so the hint must name that variable.
        return "Error: Gemini API key not configured", "Please set GOOGLE_API_KEY environment variable"
    if input_image is None:
        return "Error: No image provided", "Please provide an image"
    if extra_options is None:
        extra_options = []

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    prompt_str = _select_prompt_template(caption_type, caption_length)
    if extra_options:
        prompt_str += " " + " ".join(extra_options)
    # Formatting happens after the extra options are appended so that
    # placeholders inside them (e.g. {name}) are substituted as well.
    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    # Tolerate custom_prompt=None in addition to the empty string.
    custom_prompt = (custom_prompt or "").strip()
    if custom_prompt:
        prompt_str = custom_prompt

    try:
        if clip_model is not None:
            image_features = get_image_features(input_image, clip_model, image_adapter)
            print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")

        full_prompt = (
            "You are a helpful image captioner.\n"
            f"{prompt_str}\n"
            "Please analyze the provided image and generate a caption according to the instructions above. Just only the caption text, no additional information."
        )
        response = gemini_model.generate_content([full_prompt, input_image])
        if response.text:
            caption = response.text.strip()
        else:
            caption = "Failed to generate caption"
    except Exception as e:
        # Boundary with an external service: report the failure to the caller
        # as text rather than letting it propagate.
        print(f"Error generating caption: {str(e)}")
        return prompt_str, f"Error: {str(e)}"

    return prompt_str, caption
def caption_image_from_path(image_path: str, **kwargs):
    """Caption the image stored at ``image_path``.

    Args:
        image_path: Path of the image file to open.
        **kwargs: Forwarded unchanged to ``generate_caption``.

    Returns:
        Whatever ``generate_caption`` returns for the opened image.
    """
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released even when generate_caption returns early.
    with Image.open(image_path) as image:
        return generate_caption(image, **kwargs)
def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
    """Caption an image file, print the caption and return it.

    Args:
        image_path: Path of the image file to open.
        caption_type: Caption type passed to ``generate_caption``.

    Returns:
        The second element of ``generate_caption``'s result (the caption, or
        an error/hint string when captioning failed).
    """
    # Close the file handle deterministically instead of relying on GC.
    with Image.open(image_path) as image:
        _, caption = generate_caption(image, caption_type=caption_type)
    print(f"Caption: {caption}")
    return caption
|