Adieee5 committed
Commit 19dc712 · verified · 1 parent: e52bf46

Upload 7 files

Files changed (7):
  1. README.md +12 -0
  2. app.py +153 -0
  3. generate_caption.py +151 -0
  4. image_adapter.py +111 -0
  5. initializer.py +13 -0
  6. model_initial.py +41 -0
  7. requirements.txt +19 -3
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Image Prompting-and-Captioning
+ emoji: 🖼️
+ colorFrom: red
+ colorTo: blue
+ sdk: streamlit
+ sdk_version: 1.45.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,153 @@
+ # Make a virtual environment (a GPU is required!):
+ #   python -m venv .venv
+ # Activate the virtual environment:
+ #   Windows: .venv\Scripts\activate
+ #   Linux/Mac: source .venv/bin/activate
+
+ # Install packages:
+ #   pip install -r requirements.txt
+
+ import streamlit as st
+ from generate_caption import generate_caption
+ from PIL import Image
+
+ # Set page config
+ st.set_page_config(
+     page_title="AI Image Caption & Prompt Generator",
+     page_icon="🖼️",
+     layout="wide"
+ )
+
+ # Title and description
+ st.title("🖼️ AI Image Caption & Prompt Generator")
+
+ # Create two columns for layout
+ col1, col2 = st.columns([1, 1])
+
+ with col1:
+     st.header("📤 Upload & Configure")
+
+     # File uploader
+     uploaded_file = st.file_uploader(
+         "Choose an image file",
+         type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'webp'],
+         help="Supported formats: PNG, JPG, JPEG, GIF, BMP, TIFF, WebP"
+     )
+
+     # Caption options
+     st.subheader("Caption Options")
+
+     caption_type = st.selectbox(
+         "Caption Type",
+         options=["MidJourney", "Descriptive", "Training Prompt"],
+         index=0,
+         help="Choose the style of caption you want to generate"
+     )
+
+     caption_length = st.selectbox(
+         "Caption Length",
+         options=["short", "any", "long"],
+         index=0,
+         help="Select the desired length of the caption"
+     )
+
+     # Generate button
+     generate_btn = st.button("🎯 Generate Caption", type="primary", use_container_width=True)
+
+ with col2:
+     st.header("Preview & Results")
+
+     if uploaded_file is not None:
+         # Display uploaded image
+         image = Image.open(uploaded_file)
+         st.image(image, caption="Uploaded Image", use_container_width=True)
+
+         # Generate caption when the button is clicked
+         if generate_btn:
+             with st.spinner("Generating caption... This may take a moment."):
+                 try:
+                     # Generate caption
+                     prompt_used, caption = generate_caption(
+                         image,
+                         caption_type=caption_type,
+                         caption_length=caption_length,
+                         extra_options=None
+                     )
+
+                     # Display results
+                     st.success("Caption generated successfully!")
+
+                     # Caption result
+                     st.subheader("📝 Generated Caption")
+                     st.write(caption)
+
+                     # Code block with a built-in copy button
+                     st.code(caption, language=None)
+
+                     # Additional info
+                     with st.expander("ℹ️ Generation Details"):
+                         st.write(f"**Caption Type:** {caption_type}")
+                         st.write(f"**Caption Length:** {caption_length}")
+                         if prompt_used:
+                             st.write(f"**Prompt Used:** {prompt_used}")
+
+                 except Exception as e:
+                     st.error(f"Error generating caption: {str(e)}")
+                     st.info("Please make sure you have installed all required dependencies and your GPU is properly configured.")
+
+     else:
+         st.markdown("""
+ ### How to use:
+ 1. Upload an image using the file uploader
+ 2. Select your preferred caption type and length
+ 3. Click 'Generate Caption' to create your AI caption
+
+ ### Caption Types:
+ - **MidJourney**: Optimized for AI art generation prompts
+ - **Descriptive**: Detailed description of the image content
+ - **Training Prompt**: Formatted for AI model training
+ """)
+
+ # Sidebar with additional information
+ with st.sidebar:
+     st.header("🔧 System Requirements")
+     st.markdown("""
+ **Required Setup:**
+ - GPU-enabled environment
+ - Virtual environment activated
+ - All dependencies installed via `requirements.txt`
+
+ **Supported Image Formats:**
+ - PNG, JPG, JPEG
+ - GIF, BMP, TIFF, WebP
+ """)
+
+     st.header("💡 Tips")
+     st.markdown("""
+ - Higher-quality images produce better captions
+ - Different caption types serve different purposes
+ - Short captions are more focused; long ones are more detailed
+ """)
+
+     st.header("⚙️ Setup Instructions")
+     with st.expander("Click to view setup commands"):
+         st.code("""
+ # Create virtual environment
+ python -m venv .venv
+
+ # Activate virtual environment
+ # Windows:
+ .venv\\Scripts\\activate
+ # Linux/Mac:
+ source .venv/bin/activate
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run Streamlit app
+ streamlit run app.py
+ """, language="bash")
+
+ # Footer
+ st.markdown("Built by [Aditya Singh]")
generate_caption.py ADDED
@@ -0,0 +1,151 @@
+ import torch
+ from PIL import Image
+ import torchvision.transforms.functional as TVF
+ import google.generativeai as genai
+ import os
+ from dotenv import load_dotenv
+
+ # load_dotenv()
+ # GEMINI_API_KEY = os.getenv('GOOGLE_API_KEY')
+ GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
+
+ if GEMINI_API_KEY:
+     genai.configure(api_key=GEMINI_API_KEY)
+     gemini_model = genai.GenerativeModel('gemini-1.5-flash')
+ else:
+     print("Warning: GOOGLE_API_KEY not found in environment variables")
+     gemini_model = None
+
+ CAPTION_TYPE_MAP = {
+     "Descriptive": [
+         "Write a descriptive caption for this image in a formal tone.",
+         "Write a descriptive caption for this image in a formal tone within {word_count} words.",
+         "Write a {length} descriptive caption for this image in a formal tone.",
+     ],
+     "Training Prompt": [
+         "Write a stable diffusion prompt for this image.",
+         "Write a stable diffusion prompt for this image within {word_count} words.",
+         "Write a {length} stable diffusion prompt for this image.",
+     ],
+     "MidJourney": [
+         "Write a MidJourney prompt for this image.",
+         "Write a MidJourney prompt for this image within {word_count} words.",
+         "Write a {length} MidJourney prompt for this image.",
+     ],
+ }
+
+
+ def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
+     """Extract features from an image using CLIP."""
+     image = input_image.resize((384, 384), Image.LANCZOS)
+     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+
+     with torch.no_grad():
+         vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+
+     if image_adapter is not None:
+         embedded_images = image_adapter(vision_outputs.hidden_states)
+         return embedded_images
+     else:
+         return vision_outputs.last_hidden_state
+
+
+ def generate_caption(input_image: Image.Image,
+                      caption_type: str = "Descriptive",
+                      caption_length: str = "long",
+                      extra_options: list = None,
+                      name_input: str = "",
+                      custom_prompt: str = "",
+                      clip_model=None,
+                      image_adapter=None):
+     """
+     Generate a caption for an image using the Gemini API.
+
+     Args:
+         input_image: PIL Image object
+         caption_type: Type of caption ("Descriptive", "Training Prompt", "MidJourney")
+         caption_length: Length specification ("any", "short", "long", etc., or a number as a string)
+         extra_options: List of extra options appended to the prompt
+         name_input: Name to use for a person/character in the image
+         custom_prompt: Custom prompt that overrides the default settings
+         clip_model: CLIP model (optional, for compatibility)
+         image_adapter: Image adapter model (optional, for compatibility)
+
+     Returns:
+         tuple: (prompt_used, caption)
+     """
+     if gemini_model is None:
+         return "Error: Gemini API key not configured", "Please set the GOOGLE_API_KEY environment variable"
+
+     if input_image is None:
+         return "Error: No image provided", "Please provide an image"
+
+     if extra_options is None:
+         extra_options = []
+
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     length = None if caption_length == "any" else caption_length
+
+     if isinstance(length, str):
+         try:
+             length = int(length)
+         except ValueError:
+             pass
+
+     if length is None:
+         map_idx = 0
+     elif isinstance(length, int):
+         map_idx = 1
+     elif isinstance(length, str):
+         map_idx = 2
+     else:
+         raise ValueError(f"Invalid caption length: {length}")
+
+     prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
+
+     if len(extra_options) > 0:
+         prompt_str += " " + " ".join(extra_options)
+
+     prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)
+
+     if custom_prompt.strip() != "":
+         prompt_str = custom_prompt.strip()
+
+     try:
+         if clip_model is not None:
+             image_features = get_image_features(input_image, clip_model, image_adapter)
+             print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")
+
+         full_prompt = f"""You are a helpful image captioner.
+
+ {prompt_str}
+
+ Please analyze the provided image and generate a caption according to the instructions above. Return only the caption text, with no additional information."""
+
+         response = gemini_model.generate_content([full_prompt, input_image])
+
+         if response.text:
+             caption = response.text.strip()
+         else:
+             caption = "Failed to generate caption"
+
+     except Exception as e:
+         print(f"Error generating caption: {str(e)}")
+         return prompt_str, f"Error: {str(e)}"
+
+     return prompt_str, caption
+
+
+ def caption_image_from_path(image_path: str, **kwargs):
+     """Caption an image from a file path."""
+     image = Image.open(image_path)
+     return generate_caption(image, **kwargs)
+
+
+ def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
+     """Simple interface to caption an image."""
+     image = Image.open(image_path)
+     prompt_used, caption = generate_caption(image, caption_type=caption_type)
+     print(f"Caption: {caption}")
+     return caption
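A minimal usage sketch for the module above, assuming `GOOGLE_API_KEY` is set and `sample.jpg` is a placeholder for any local image; with no `clip_model` passed, only the Gemini API path runs:

```python
from PIL import Image
from generate_caption import generate_caption

# Placeholder image path; any local image works.
image = Image.open("sample.jpg")

prompt_used, caption = generate_caption(image, caption_type="MidJourney", caption_length="short")
print("Prompt:", prompt_used)
print("Caption:", caption)
```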
image_adapter.py ADDED
@@ -0,0 +1,111 @@
+ from torch import nn
+ import torch
+
+
+ class ImageAdapter(nn.Module):
+     """Projects CLIP vision hidden states into the language model's embedding space."""
+
+     def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
+         super().__init__()
+         self.deep_extract = deep_extract
+
+         if self.deep_extract:
+             input_features = input_features * 5
+
+         self.linear1 = nn.Linear(input_features, output_features)
+         self.activation = nn.GELU()
+         self.linear2 = nn.Linear(output_features, output_features)
+         self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
+         self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
+
+         self.other_tokens = nn.Embedding(3, output_features)
+         self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)
+
+     def forward(self, vision_outputs: torch.Tensor):
+         if self.deep_extract:
+             x = torch.concat((
+                 vision_outputs[-2],
+                 vision_outputs[3],
+                 vision_outputs[7],
+                 vision_outputs[13],
+                 vision_outputs[20],
+             ), dim=-1)
+             assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
+             assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
+         else:
+             x = vision_outputs[-2]
+
+         x = self.ln1(x)
+
+         if self.pos_emb is not None:
+             assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
+             x = x + self.pos_emb
+
+         x = self.linear1(x)
+         x = self.activation(x)
+         x = self.linear2(x)
+
+         # Wrap the image tokens in two learned boundary tokens
+         other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
+         assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
+         x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
+
+         return x
+
+     def get_eot_embedding(self):
+         return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
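A small shape-check sketch for the adapter; the sizes are assumptions chosen to match model_initial.py (1152 is the hidden size of google/siglip-so400m-patch14-384, whose 384×384 input yields 729 patch tokens):

```python
import torch
from image_adapter import ImageAdapter

adapter = ImageAdapter(input_features=1152, output_features=4096,
                       ln1=False, pos_emb=False, num_image_tokens=38, deep_extract=False)

# Stand-in for clip_model(...).hidden_states: a list of per-layer activations;
# with deep_extract=False, only index -2 is read.
hidden_states = [torch.randn(1, 729, 1152) for _ in range(27)]

tokens = adapter(hidden_states)
print(tokens.shape)  # torch.Size([1, 731, 4096]): 729 image tokens plus 2 boundary tokens
```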
initializer.py ADDED
@@ -0,0 +1,13 @@
+ from model_initial import initialize_models
+ from generate_caption import generate_caption
+ from PIL import Image
+
+ if __name__ == "__main__" or "get_ipython" in globals():
+     print("Initializing models...")
+     try:
+         clip_model, image_adapter = initialize_models()
+         print("Models initialized successfully!")
+     except Exception as e:
+         print(f"Error initializing models: {e}")
+         print("You can still use the basic caption functionality with the Gemini API only")
+         clip_model, image_adapter = None, None
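If initialization succeeds, the models can be handed to generate_caption so the CLIP feature path is exercised as well. A hedged sketch continuing from the script above (`sample.jpg` is a placeholder path):

```python
from PIL import Image
from generate_caption import generate_caption

image = Image.open("sample.jpg")  # placeholder path
prompt_used, caption = generate_caption(
    image,
    caption_type="Descriptive",
    clip_model=clip_model,        # may be None if initialization failed
    image_adapter=image_adapter,
)
print(caption)
```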
model_initial.py ADDED
@@ -0,0 +1,41 @@
+ from transformers import AutoModel, AutoProcessor
+ from pathlib import Path
+ from image_adapter import ImageAdapter
+ import torch
+
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
+ CHECKPOINT_PATH = Path("Adieee5/Image-captioning")
+ # CHECKPOINT_PATH = Path("cheackpoints")
+
+
+ def initialize_models():
+     """Initialize and load all models."""
+     print("Loading CLIP")
+     clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
+     clip_model = AutoModel.from_pretrained(CLIP_PATH)
+     clip_model = clip_model.vision_model
+
+     if (CHECKPOINT_PATH / "clip_model.pt").exists():
+         print("Loading VLM's custom vision model")
+         checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
+         checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
+         clip_model.load_state_dict(checkpoint)
+         del checkpoint
+     else:
+         print("Custom CLIP weights not found, using default weights")
+
+     clip_model.eval()
+     clip_model.requires_grad_(False)
+     clip_model.to("cpu")
+
+     image_adapter = None
+     if (CHECKPOINT_PATH / "image_presenter.pt").exists():
+         print("Loading image adapter")
+         image_adapter = ImageAdapter(clip_model.config.hidden_size, 4096, False, False, 38, False)
+         image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_presenter.pt", map_location="cpu"))
+         image_adapter.eval()
+         image_adapter.to("cpu")
+     else:
+         print("Image adapter not found, will use CLIP features directly")
+
+     return clip_model, image_adapter
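Note that `CHECKPOINT_PATH` above is a Hugging Face repo id rather than a local directory, so the `.exists()` checks fall through to the default weights unless a relative directory with that name happens to exist. A minimal sketch, assuming the checkpoint files live in the Adieee5/Image-captioning Hub repo, that downloads them before loading:

```python
from pathlib import Path
from huggingface_hub import snapshot_download
import model_initial

# Fetch (or reuse a cached copy of) the checkpoint repo, then point the loader at it.
local_dir = snapshot_download("Adieee5/Image-captioning")
model_initial.CHECKPOINT_PATH = Path(local_dir)  # override the module-level constant

clip_model, image_adapter = model_initial.initialize_models()
```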
requirements.txt CHANGED
@@ -1,3 +1,19 @@
- altair
- pandas
- streamlit
+ huggingface_hub
+ accelerate
+ torch
+ transformers
+ sentencepiece
+ peft
+ torchvision
+ protobuf
+ google-ai-generativelanguage==0.4.0
+ google-api-core==2.24.2
+ google-auth==2.38.0
+ google-generativeai==0.4.1
+ langchain==0.1.13
+ langchain-community==0.0.29
+ langchain-google-genai==0.0.11
+ python-dotenv
+ streamlit
+ watchdog