Upload 7 files
- README.md +12 -0
- app.py +153 -0
- generate_caption.py +151 -0
- image_adapter.py +111 -0
- initializer.py +13 -0
- model_initial.py +41 -0
- requirements.txt +19 -3
README.md
ADDED
@@ -0,0 +1,12 @@
---
title: Image Prompting-and-Captioning
emoji: 🖼️
colorFrom: red
colorTo: blue
sdk: streamlit
sdk_version: 1.45.1
app_file: app.py
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,153 @@
# Create a virtual environment (a GPU-enabled environment is required)
# python -m venv .venv
# Activate the virtual environment
# Windows: .venv\Scripts\activate
# Linux/Mac: source .venv/bin/activate

# Install packages
# pip install -r requirements.txt

import streamlit as st
from generate_caption import generate_caption
from PIL import Image
import io

# Set page config
st.set_page_config(
    page_title="AI Image Caption & Prompt Generator",
    page_icon="🖼️",
    layout="wide"
)

# Title and description
st.title("🖼️ AI Image Caption & Prompt Generator")

# Create two columns for layout
col1, col2 = st.columns([1, 1])

with col1:
    st.header("📤 Upload & Configure")

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose an image file",
        type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'webp'],
        help="Supported formats: PNG, JPG, JPEG, GIF, BMP, TIFF, WebP"
    )

    # Caption options
    st.subheader("Caption Options")

    caption_type = st.selectbox(
        "Caption Type",
        options=["MidJourney", "Descriptive", "Training Prompt"],
        index=0,
        help="Choose the style of caption you want to generate"
    )

    caption_length = st.selectbox(
        "Caption Length",
        options=["short", "any", "long"],
        index=0,
        help="Select the desired length of the caption"
    )

    # Generate button
    generate_btn = st.button("🎯 Generate Caption", type="primary", use_container_width=True)

with col2:
    st.header("Preview & Results")

    if uploaded_file is not None:
        # Display uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Generate caption when the button is clicked
        if generate_btn:
            with st.spinner("Generating caption... This may take a moment."):
                try:
                    # Generate caption
                    prompt_used, caption = generate_caption(
                        image,
                        caption_type=caption_type,
                        caption_length=caption_length,
                        extra_options=None
                    )

                    # Display results
                    st.success("Caption generated successfully!")

                    # Caption result
                    st.subheader("📝 Generated Caption")
                    st.write(f"{caption}")

                    # Code block doubles as a copy-to-clipboard widget
                    st.code(caption, language=None)

                    # Additional info
                    with st.expander("ℹ️ Generation Details"):
                        st.write(f"**Caption Type:** {caption_type}")
                        st.write(f"**Caption Length:** {caption_length}")
                        if prompt_used:
                            st.write(f"**Prompt Used:** {prompt_used}")

                except Exception as e:
                    st.error(f"Error generating caption: {str(e)}")
                    st.info("Please make sure you have installed all required dependencies and your GPU is properly configured.")

    else:
        st.markdown("""
### How to use:
1. Upload an image using the file uploader
2. Select your preferred caption type and length
3. Click 'Generate Caption' to create your AI caption

### Caption Types:
- **MidJourney**: Optimized for AI art generation prompts
- **Descriptive**: Detailed description of the image content
- **Training Prompt**: Formatted for AI model training
""")

# Sidebar with additional information
with st.sidebar:
    st.header("🔧 System Requirements")
    st.markdown("""
**Required Setup:**
- GPU-enabled environment
- Virtual environment activated
- All dependencies installed via `requirements.txt`

**Supported Image Formats:**
- PNG, JPG, JPEG
- GIF, BMP, TIFF, WebP
""")

    st.header("💡 Tips")
    st.markdown("""
- Higher-quality images produce better captions
- Different caption types serve different purposes
- Short captions are more focused, long ones more detailed
""")

    st.header("⚙️ Setup Instructions")
    with st.expander("Click to view setup commands"):
        st.code("""
# Create virtual environment
python -m venv .venv

# Activate virtual environment
# Windows:
.venv\\Scripts\\activate
# Linux/Mac:
source .venv/bin/activate

# Install dependencies
pip install -r requirements.txt

# Run Streamlit app
streamlit run app.py
""", language="bash")

# Footer
st.markdown("Built by [Aditya Singh]")
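The button handler above is the only place the model pipeline is invoked, so it can be exercised without the UI for quick debugging. A minimal sketch reproducing the handler's call from a plain Python session; `sample.jpg` is a placeholder path and a configured `GOOGLE_API_KEY` is assumed:

```python
# Reproduce the button handler's generate_caption call without Streamlit.
from PIL import Image
from generate_caption import generate_caption

image = Image.open("sample.jpg")  # placeholder test image
prompt_used, caption = generate_caption(
    image,
    caption_type="MidJourney",   # same choices as the selectbox: MidJourney, Descriptive, Training Prompt
    caption_length="short",      # same choices as the selectbox: short, any, long
    extra_options=None,
)
print("Prompt:", prompt_used)
print("Caption:", caption)
```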
generate_caption.py
ADDED
@@ -0,0 +1,151 @@
import torch
from PIL import Image
import torchvision.transforms.functional as TVF
import google.generativeai as genai
import os
from dotenv import load_dotenv

# load_dotenv()
# GEMINI_API_KEY = os.getenv('GOOGLE_API_KEY')
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
else:
    print("Warning: GOOGLE_API_KEY not found in environment variables")
    gemini_model = None

CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Training Prompt": [
        "Write a stable diffusion prompt for this image.",
        "Write a stable diffusion prompt for this image within {word_count} words.",
        "Write a {length} stable diffusion prompt for this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
}


def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
    """Extract features from an image using the CLIP vision model."""
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)

    if image_adapter is not None:
        embedded_images = image_adapter(vision_outputs.hidden_states)
        return embedded_images
    else:
        return vision_outputs.last_hidden_state


def generate_caption(input_image: Image.Image,
                     caption_type: str = "Descriptive",
                     caption_length: str = "long",
                     extra_options: list = None,
                     name_input: str = "",
                     custom_prompt: str = "",
                     clip_model=None,
                     image_adapter=None):
    """
    Generate a caption for an image using the Gemini API.

    Args:
        input_image: PIL Image object
        caption_type: Type of caption ("Descriptive", "Training Prompt", "MidJourney")
        caption_length: Length specification ("any", "short", "long", etc., or a number as a string)
        extra_options: List of extra options appended to the prompt
        name_input: Name to use for a person/character in the image
        custom_prompt: Custom prompt that overrides the default settings
        clip_model: CLIP model (optional, for compatibility)
        image_adapter: Image adapter model (optional, for compatibility)

    Returns:
        tuple: (prompt_used, generated_caption)
    """
    if gemini_model is None:
        return "Error: Gemini API key not configured", "Please set the GOOGLE_API_KEY environment variable"

    if input_image is None:
        return "Error: No image provided", "Please provide an image"

    if extra_options is None:
        extra_options = []

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    length = None if caption_length == "any" else caption_length

    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    elif isinstance(length, str):
        map_idx = 2
    else:
        raise ValueError(f"Invalid caption length: {length}")

    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]

    if len(extra_options) > 0:
        prompt_str += " " + " ".join(extra_options)

    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()

    try:
        if clip_model is not None:
            image_features = get_image_features(input_image, clip_model, image_adapter)
            print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")

        full_prompt = f"""You are a helpful image captioner.

{prompt_str}

Please analyze the provided image and generate a caption according to the instructions above. Return only the caption text, with no additional information."""

        response = gemini_model.generate_content([full_prompt, input_image])

        if response.text:
            caption = response.text.strip()
        else:
            caption = "Failed to generate caption"

    except Exception as e:
        print(f"Error generating caption: {str(e)}")
        return prompt_str, f"Error: {str(e)}"

    return prompt_str, caption


def caption_image_from_path(image_path: str, **kwargs):
    """Caption an image from a file path."""
    image = Image.open(image_path)
    return generate_caption(image, **kwargs)


def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
    """Simple interface to caption an image."""
    image = Image.open(image_path)
    prompt_used, caption = generate_caption(image, caption_type=caption_type)
    print(f"Caption: {caption}")
    return caption
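Two details of this API are easy to miss: the module reads `GOOGLE_API_KEY` at import time, so any `.env` loading has to happen before the import, and a numeric string passed as `caption_length` selects the `{word_count}` prompt template rather than the `{length}` one. A minimal local-development sketch illustrating both; the `.env` file and `photo.jpg` path are assumptions for illustration:

```python
# Load the key from a local .env file *before* importing generate_caption,
# since the module reads GOOGLE_API_KEY from the environment at import time.
from dotenv import load_dotenv
load_dotenv()  # assumes a .env file containing GOOGLE_API_KEY=<your key>

from PIL import Image
from generate_caption import generate_caption

# A numeric string picks the "... within {word_count} words." template.
prompt_used, caption = generate_caption(
    Image.open("photo.jpg"),      # placeholder image path
    caption_type="Descriptive",
    caption_length="50",
)
print(prompt_used)
print(caption)
```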
image_adapter.py
ADDED
@@ -0,0 +1,111 @@
from torch import nn
from transformers import AutoModel, AutoProcessor
from pathlib import Path
import torch
import torch.amp.autocast_mode
from PIL import Image
import os
import torchvision.transforms.functional as TVF
import base64
import io


class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract

        if self.deep_extract:
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))

        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        if self.deep_extract:
            # Concatenate hidden states from several vision layers for a richer representation
            x = torch.concat((
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            ), dim=-1)
            assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
        else:
            x = vision_outputs[-2]

        x = self.ln1(x)

        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            x = x + self.pos_emb

        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        # Wrap the projected image tokens with learned begin/end-of-image tokens
        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)

        return x

    def get_eot_embedding(self):
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
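Since the adapter only indexes `hidden_states[-2]` when `deep_extract` is False, its shape contract can be checked without loading any pretrained weights. A minimal sketch with made-up dimensions (the real SigLIP hidden size and token count differ):

```python
# Shape check for ImageAdapter with deep_extract=False and no positional embedding;
# the dimensions below are illustrative, not SigLIP's real ones.
import torch
from image_adapter import ImageAdapter

batch, tokens, hidden, out_dim = 2, 16, 64, 128
adapter = ImageAdapter(hidden, out_dim, ln1=False, pos_emb=False,
                       num_image_tokens=tokens, deep_extract=False)

# forward() only indexes vision_outputs[-2], so a plain list stands in for
# the hidden_states tuple returned by the vision tower.
fake_hidden_states = [torch.randn(batch, tokens, hidden) for _ in range(4)]
embedded = adapter(fake_hidden_states)

# Two learned tokens wrap the projected image tokens, hence tokens + 2.
print(embedded.shape)  # torch.Size([2, 18, 128])
```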
initializer.py
ADDED
@@ -0,0 +1,13 @@
from model_initial import initialize_models
from generate_caption import generate_caption
from PIL import Image

if __name__ == "__main__" or "get_ipython" in globals():
    print("Initializing models...")
    try:
        clip_model, image_adapter = initialize_models()
        print("Models initialized successfully!")
    except Exception as e:
        print(f"Error initializing models: {e}")
        print("You can still use the basic caption functionality with Gemini API only")
        clip_model, image_adapter = None, None
model_initial.py
ADDED
@@ -0,0 +1,41 @@
from transformers import AutoModel, AutoProcessor
from pathlib import Path
from image_adapter import ImageAdapter
import torch

CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("Adieee5/Image-captioning")
# CHECKPOINT_PATH = Path("cheackpoints")


def initialize_models():
    """Initialize and load all models"""
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model

    if (CHECKPOINT_PATH / "clip_model.pt").exists():
        print("Loading VLM's custom vision model")
        checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
        checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
        clip_model.load_state_dict(checkpoint)
        del checkpoint
    else:
        print("Custom CLIP weights not found, using default weights")

    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cpu")

    image_adapter = None
    if (CHECKPOINT_PATH / "image_presenter.pt").exists():
        print("Loading image adapter")
        image_adapter = ImageAdapter(clip_model.config.hidden_size, 4096, False, False, 38, False)
        image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_presenter.pt", map_location="cpu"))
        image_adapter.eval()
        image_adapter.to("cpu")
    else:
        print("Image adapter not found, will use CLIP features directly")

    return clip_model, image_adapter
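Putting the pieces together: `initialize_models()` returns the vision tower and the optional adapter, and `generate_caption()` accepts them even though the caption itself comes from Gemini. A rough end-to-end sketch, assuming the optional checkpoints are reachable and `GOOGLE_API_KEY` is set; `example.jpg` is a placeholder path:

```python
# End-to-end wiring of model_initial and generate_caption.
from PIL import Image
from model_initial import initialize_models
from generate_caption import generate_caption

clip_model, image_adapter = initialize_models()   # image_adapter is None if its checkpoint is missing
prompt_used, caption = generate_caption(
    Image.open("example.jpg"),
    caption_type="Training Prompt",
    caption_length="long",
    clip_model=clip_model,        # optional: CLIP features are only extracted and logged
    image_adapter=image_adapter,
)
print(caption)
```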
requirements.txt
CHANGED
@@ -1,3 +1,19 @@
huggingface_hub
accelerate
torch
transformers
sentencepiece
peft
torchvision
protobuf
google-ai-generativelanguage==0.4.0
google-api-core==2.24.2
google-auth==2.38.0
google-generativeai==0.4.1
langchain==0.1.13
langchain-community==0.0.29
langchain-google-genai==0.0.11
python-dotenv
streamlit
watchdog