import os

import torch
import torchvision.transforms.functional as TVF
from PIL import Image
import google.generativeai as genai

# Configure the Gemini client from the environment; fall back to None so that
# importing this module without a key does not raise.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
else:
    print("Warning: GOOGLE_API_KEY not found in environment variables")
    gemini_model = None

# Prompt templates, indexed by caption type and then by length mode:
# [0] no length constraint, [1] numeric word count, [2] named length ("short", "long", ...).
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Training Prompt": [
        "Write a stable diffusion prompt for this image.",
        "Write a stable diffusion prompt for this image within {word_count} words.",
        "Write a {length} stable diffusion prompt for this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
}


def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
    """Extract features from an image using a CLIP vision model."""
    # CLIP expects a fixed input size; 384x384 matches the model used here.
    image = input_image.resize((384, 384), Image.LANCZOS)
    # HWC uint8 -> NCHW float in [0, 1], then normalize to [-1, 1].
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    # Keep the inputs on the same device as the model (no-op on CPU).
    pixel_values = pixel_values.to(next(clip_model.parameters()).device)

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        if image_adapter is not None:
            # Project the hidden states into the adapter's embedding space.
            return image_adapter(vision_outputs.hidden_states)
        return vision_outputs.last_hidden_state


def generate_caption(input_image: Image.Image, caption_type: str = "Descriptive",
                     caption_length: str = "long", extra_options: list = None,
                     name_input: str = "", custom_prompt: str = "",
                     clip_model=None, image_adapter=None):
    """
    Generate a caption for an image using the Gemini API.

    The model is instructed to return plain prose: no bullet points, proper
    punctuation, and no extra information.

    Args:
        input_image: PIL Image object
        caption_type: Type of caption ("Descriptive", "Training Prompt", "MidJourney")
        caption_length: Length specification ("any", "short", "long", etc., or a number as a string)
        extra_options: List of extra instructions appended to the prompt
        name_input: Name to use for a person/character in the image
        custom_prompt: Custom prompt that overrides the default settings
        clip_model: CLIP model (optional, for compatibility)
        image_adapter: Image adapter model (optional, for compatibility)

    Returns:
        tuple: (prompt_used, generated_caption)
    """
    if gemini_model is None:
        return "Error: Gemini API key not configured", "Please set the GOOGLE_API_KEY environment variable"
    if input_image is None:
        return "Error: No image provided", "Please provide an image"
    if extra_options is None:
        extra_options = []

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Map the requested length onto a template index: None -> unconstrained,
    # int -> word-count template, str -> named-length template.
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    elif isinstance(length, str):
        map_idx = 2
    else:
        raise ValueError(f"Invalid caption length: {length}")

    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
    if len(extra_options) > 0:
        prompt_str += " " + " ".join(extra_options)
    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    # A non-empty custom prompt overrides everything assembled above.
    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()

    try:
        # CLIP features are computed for compatibility/debugging only; the raw
        # PIL image is what actually gets sent to Gemini below.
        if clip_model is not None:
            image_features = get_image_features(input_image, clip_model, image_adapter)
            print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")

        full_prompt = f"""You are a helpful image captioner.

{prompt_str}

Please analyze the provided image and generate a caption according to the instructions above. Return only the caption text, no additional information."""

        response = gemini_model.generate_content([full_prompt, input_image])
        caption = response.text.strip() if response.text else "Failed to generate caption"
    except Exception as e:
        print(f"Error generating caption: {str(e)}")
        return prompt_str, f"Error: {str(e)}"

    return prompt_str, caption


def caption_image_from_path(image_path: str, **kwargs):
    """Caption an image loaded from a file path."""
    image = Image.open(image_path)
    return generate_caption(image, **kwargs)


def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
    """Simple interface: caption an image and print the result."""
    image = Image.open(image_path)
    prompt_used, caption = generate_caption(image, caption_type=caption_type)
    print(f"Caption: {caption}")
    return caption
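

# Minimal usage sketch (not part of the original module): demonstrates the two
# convenience entry points. "example.jpg" is a hypothetical path; running this
# requires GOOGLE_API_KEY to be set in the environment.
if __name__ == "__main__":
    demo_path = "example.jpg"  # hypothetical; replace with a real image file

    if os.path.exists(demo_path):
        # One-shot caption with the default "Descriptive" style.
        caption_image_simple(demo_path)

        # Full control: a MidJourney-style prompt capped at roughly 30 words
        # (a numeric string selects the word-count template).
        prompt_used, caption = caption_image_from_path(
            demo_path,
            caption_type="MidJourney",
            caption_length="30",
        )
        print(f"Prompt used: {prompt_used}")
        print(f"Caption: {caption}")
    else:
        print(f"Demo image not found: {demo_path}")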