|
import torch |
|
from PIL import Image |
|
import torchvision.transforms.functional as TVF |
|
import google.generativeai as genai |
|
import os |
|
|
|
# The key is read from the GOOGLE_API_KEY environment variable at import time.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")

# Start from "no model available" and only build a client when a key exists,
# so the rest of the module can simply test `gemini_model is None`.
gemini_model = None
if not GEMINI_API_KEY:
    print("Warning: GOOGLE_API_KEY not found in environment variables")
else:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
|
|
|
# Prompt templates keyed by caption type. Each value is a 3-item list indexed by
# how the caller specified the length:
#   [0] no length constraint
#   [1] numeric word limit   -> uses the {word_count} placeholder
#   [2] descriptive length   -> uses the {length} placeholder
# The doubled braces keep {word_count} / {length} as literal placeholders that
# are filled in later with str.format().
CAPTION_TYPE_MAP = {
    caption_type: [
        f"Write a {request}.",
        f"Write a {request} within {{word_count}} words.",
        f"Write a {{length}} {request}.",
    ]
    for caption_type, request in (
        ("Descriptive", "descriptive caption for this image in a formal tone"),
        ("Training Prompt", "stable diffusion prompt for this image"),
        ("MidJourney", "MidJourney prompt for this image"),
    )
}
|
|
|
def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
    """Run an image through a vision model and return its features.

    The image is converted to RGB, resized to 384x384 with Lanczos resampling,
    scaled to [0, 1] and then normalized with mean 0.5 / std 0.5 (i.e. mapped
    to [-1, 1]) before being passed to ``clip_model``.

    Args:
        input_image: PIL image to encode.
        clip_model: Callable vision model. It is invoked as
            ``clip_model(pixel_values=..., output_hidden_states=True)`` and its
            result must expose ``hidden_states`` and ``last_hidden_state``.
        image_adapter: Optional callable applied to the model's
            ``hidden_states``. When given, its output is returned instead of
            ``last_hidden_state``.

    Returns:
        ``image_adapter(vision_outputs.hidden_states)`` when an adapter is
        supplied, otherwise ``vision_outputs.last_hidden_state``.
    """
    # pil_to_tensor keeps whatever channel count the PIL mode has, so a
    # grayscale, palette or RGBA image would yield 1 or 4 channels. Force RGB
    # so the tensor always has the 3 channels a vision model expects.
    image = input_image.convert("RGB").resize((384, 384), Image.LANCZOS)

    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])

    # The tensor is created on CPU. If the model advertises a device (assumed
    # to be the case for Hugging Face models - confirm for other model types),
    # move the input there to avoid a device-mismatch error.
    device = getattr(clip_model, "device", None)
    if device is not None:
        pixel_values = pixel_values.to(device)

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)

    if image_adapter is not None:
        return image_adapter(vision_outputs.hidden_states)
    return vision_outputs.last_hidden_state
|
|
|
|
|
def _build_caption_prompt(caption_type, caption_length, extra_options, name_input):
    """Assemble the instruction text from CAPTION_TYPE_MAP and the caller's options."""
    length = None if caption_length == "any" else caption_length

    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            # Not numeric: keep it as a descriptive length word such as "short".
            pass

    # Pick the template variant: 0 = unconstrained, 1 = word count, 2 = length word.
    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    elif isinstance(length, str):
        map_idx = 2
    else:
        raise ValueError(f"Invalid caption length: {length}")

    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]

    if extra_options:
        prompt_str += " " + " ".join(extra_options)

    # Formatting happens after the extra options are appended so that
    # placeholders inside the options (e.g. "{name}") are filled in as well.
    return prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)


def generate_caption(input_image: Image.Image,
                     caption_type: str = "Descriptive",
                     caption_length: str = "long",
                     extra_options: list = None,
                     name_input: str = "",
                     custom_prompt: str = "",
                     clip_model=None,
                     image_adapter=None):
    """Generate a caption for an image using the Gemini API.

    Args:
        input_image: PIL Image object.
        caption_type: Key into CAPTION_TYPE_MAP ("Descriptive",
            "Training Prompt", "MidJourney"). Ignored when ``custom_prompt``
            is given.
        caption_length: "any" for no constraint, a number (or numeric string)
            for a word limit, or any other string used as a length word
            (e.g. "short", "long"). Ignored when ``custom_prompt`` is given.
        extra_options: Extra instruction sentences appended to the prompt.
            They are passed through ``str.format`` and may use ``{name}``,
            ``{length}`` and ``{word_count}``.
        name_input: Value substituted for ``{name}`` in the prompt.
        custom_prompt: If non-blank, replaces the generated prompt entirely.
        clip_model: Optional vision model. When given, image features are
            extracted and their shape is printed; the features are not sent
            to Gemini.
        image_adapter: Optional adapter forwarded to ``get_image_features``.

    Returns:
        tuple: ``(prompt, caption)`` on success. If the Gemini call (or the
        optional feature extraction) raises, ``(prompt, "Error: <message>")``.
        If the model is not configured or no image is given, a 2-tuple of
        ``(error, hint)`` strings.

    Raises:
        KeyError: ``caption_type`` is not in CAPTION_TYPE_MAP (only when no
            custom prompt is given).
        ValueError: ``caption_length`` is neither None, an int nor a str
            (only when no custom prompt is given).
    """
    if gemini_model is None:
        # The key is read from GOOGLE_API_KEY, so that is the name to report.
        return "Error: Gemini API key not configured", "Please set GOOGLE_API_KEY environment variable"

    if input_image is None:
        return "Error: No image provided", "Please provide an image"

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # A custom prompt overrides everything else, so skip building (and
    # validating) a template prompt that would be thrown away.
    custom_prompt = (custom_prompt or "").strip()
    if custom_prompt:
        prompt_str = custom_prompt
    else:
        prompt_str = _build_caption_prompt(
            caption_type, caption_length, extra_options or [], name_input
        )

    try:
        if clip_model is not None:
            image_features = get_image_features(input_image, clip_model, image_adapter)
            print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")

        full_prompt = (
            "You are a helpful image captioner.\n"
            "\n"
            f"{prompt_str}\n"
            "\n"
            "Please analyze the provided image and generate a caption according to the instructions above. "
            "Just only the caption text, no additional information."
        )

        response = gemini_model.generate_content([full_prompt, input_image])

        if response.text:
            caption = response.text.strip()
        else:
            caption = "Failed to generate caption"

    except Exception as e:
        # Boundary with an external service: report the failure to the caller
        # as text instead of propagating it.
        print(f"Error generating caption: {str(e)}")
        return prompt_str, f"Error: {str(e)}"

    return prompt_str, caption
|
|
|
def caption_image_from_path(image_path: str, **kwargs):
    """Caption an image loaded from a file path.

    Args:
        image_path: Path of the image file to open.
        **kwargs: Forwarded unchanged to ``generate_caption``.

    Returns:
        Whatever ``generate_caption`` returns for the opened image.
    """
    # Image.open keeps the file handle open; the context manager closes it
    # once captioning has finished.
    with Image.open(image_path) as image:
        return generate_caption(image, **kwargs)
|
|
|
|
|
def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
    """Caption an image file, print the caption and return it.

    Args:
        image_path: Path of the image file to open.
        caption_type: Caption type passed to ``generate_caption``.

    Returns:
        The second element of the tuple returned by ``generate_caption``.
    """
    # Image.open keeps the file handle open; the context manager closes it
    # once captioning has finished.
    with Image.open(image_path) as image:
        _, caption = generate_caption(image, caption_type=caption_type)
    print(f"Caption: {caption}")
    return caption
|
|