import torch
from PIL import Image
import torchvision.transforms.functional as TVF
import google.generativeai as genai
import os

GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
else:
    print("Warning: GOOGLE_API_KEY not found in environment variables")
    gemini_model = None

CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Training Prompt": [
        "Write a stable diffusion prompt for this image.",
        "Write a stable diffusion prompt for this image within {word_count} words.",
        "Write a {length} stable diffusion prompt for this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
}

def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
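    """Extract features from an image using CLIP (and the optional image adapter)."""
    # Force 3-channel RGB so RGBA or grayscale inputs do not yield the wrong channel count.
    image = input_image.convert("RGB").resize((384, 384), Image.LANCZOS)
    # Scale pixel values to [0, 1], then normalize them to [-1, 1].
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])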

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)

        if image_adapter is not None:
            embedded_images = image_adapter(vision_outputs.hidden_states)
            return embedded_images
        else:
            return vision_outputs.last_hidden_state


def generate_caption(input_image: Image.Image,
                    caption_type: str = "Descriptive",
                    caption_length: str = "long",
                    extra_options: list = None,
                    name_input: str = "",
                    custom_prompt: str = "",
                    clip_model=None,
                    image_adapter=None):
    """
    Generate a caption for an image using the Gemini API.

    Args:
        input_image: PIL Image object
        caption_type: Type of caption ("Descriptive", "Training Prompt", "MidJourney")
        caption_length: "any", a descriptive length such as "short" or "long", or a word count given as a numeric string
        extra_options: List of extra instruction strings appended to the prompt
        name_input: Name to use for a person/character in the image
        custom_prompt: If non-empty, replaces the prompt built from the other settings
        clip_model: CLIP model (optional, for compatibility)
        image_adapter: Image adapter model (optional, for compatibility)

    Returns:
        tuple: (prompt_str, caption). On failure, the elements carry error text instead.
    """
    if gemini_model is None:
        return "Error: Gemini API key not configured", "Please set GOOGLE_API_KEY environment variable"

    if input_image is None:
        return "Error: No image provided", "Please provide an image"

    if extra_options is None:
        extra_options = []

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    length = None if caption_length == "any" else caption_length

    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    elif isinstance(length, str):
        map_idx = 2
    else:
        raise ValueError(f"Invalid caption length: {length}")

    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]

    if len(extra_options) > 0:
        prompt_str += " " + " ".join(extra_options)

    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()

    try:
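        # The CLIP features are only extracted and logged for compatibility;
        # they are not sent to Gemini, which receives the PIL image directly.
        if clip_model is not None:
            image_features = get_image_features(input_image, clip_model, image_adapter)
            print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")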

        full_prompt = f"""You are a helpful image captioner.

{prompt_str}

Please analyze the provided image and generate a caption according to the instructions above. Return only the caption text, with no additional information."""

        response = gemini_model.generate_content([full_prompt, input_image])

        if response.text:
            caption = response.text.strip()
        else:
            caption = "Failed to generate caption"

    except Exception as e:
        print(f"Error generating caption: {str(e)}")
        return prompt_str, f"Error: {str(e)}"

    return prompt_str, caption

def caption_image_from_path(image_path: str, **kwargs):
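    """Caption an image from a file path; returns (prompt_str, caption)."""
    # Use a context manager so the underlying file handle is closed after captioning.
    with Image.open(image_path) as image:
        return generate_caption(image, **kwargs)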


def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
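    """Simple interface: caption an image file, print the caption, and return it."""
    # Use a context manager so the underlying file handle is closed after captioning.
    with Image.open(image_path) as image:
        prompt_used, caption = generate_caption(image, caption_type=caption_type)
    print(f"Caption: {caption}")
    return caption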