Update src/app.py
src/app.py · CHANGED (+14 −26)
```diff
@@ -38,9 +38,9 @@ st.set_page_config(
 MODEL_CONFIGS = {
     "BLIP": {
         "name": "BLIP",
-        "icon": "⭐",
+        "icon": "⭐",
         "description": "BLIP (Bootstrapping Language-Image Pre-training) is designed to learn vision-language representation from noisy web data. It excels at generating detailed and accurate image descriptions.",
-        "generate_params": {"max_length": 50, "num_beams": 5, "min_length": 10, "top_p": 0.9, "repetition_penalty": 1.5}
+        "generate_params": {"max_length": 50, "num_beams": 5, "min_length": 10, "do_sample": True, "top_p": 0.9, "repetition_penalty": 1.5}  # Added do_sample=True
     },
     "ViT-GPT2": {
         "name": "ViT-GPT2",
@@ -64,9 +64,8 @@ MODEL_CONFIGS = {
 # ......................... LOADING FUNCTIONS .....................................
 @st.cache_resource
 def load_blip_model():
-    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-
-    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-
-    if torch.cuda.is_available(): model = model.to("cuda")
+    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")  # Changed to base model
+    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
     return model, processor
 
 @st.cache_resource
@@ -74,21 +73,18 @@ def load_vit_gpt2_model():
     model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
     feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
     tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-    if torch.cuda.is_available(): model = model.to("cuda")
     return model, feature_extractor, tokenizer
 
 @st.cache_resource
 def load_git_model():
     processor = AutoProcessor.from_pretrained("microsoft/git-base")
     model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
-    if torch.cuda.is_available(): model = model.to("cuda")
     return model, processor
 
 @st.cache_resource
 def load_clip_model():
-    processor = CLIPProcessor.from_pretrained("openai/clip-vit-
-    model = CLIPModel.from_pretrained("openai/clip-vit-
-    if torch.cuda.is_available(): model = model.to("cuda")
+    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")  # Changed to smaller model
+    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
     return model, processor
 
 # ......................... IMAGE PROCESSING ...............................
@@ -132,9 +128,7 @@ def generate_caption(image, model_name, models_data):
 
 def get_blip_caption(image, model, processor):
     try:
-        inputs = processor(image, return_tensors="pt")
-        if torch.cuda.is_available():
-            inputs = {k: v.to("cuda") for k, v in inputs.items()}
+        inputs = processor(images=image, return_tensors="pt", padding=True, truncation=True)
         output = model.generate(**inputs, **MODEL_CONFIGS["BLIP"]["generate_params"])
         caption = processor.decode(output[0], skip_special_tokens=True)
         return caption
@@ -143,10 +137,12 @@ def get_blip_caption(image, model, processor):
 
 def get_vit_gpt2_caption(image, model, feature_extractor, tokenizer):
     try:
-        inputs = feature_extractor(images=image, return_tensors="pt")
-
-        inputs
-
+        inputs = feature_extractor(images=image, return_tensors="pt", padding=True)
+        output = model.generate(
+            pixel_values=inputs.pixel_values,
+            **MODEL_CONFIGS["ViT-GPT2"]["generate_params"],
+            attention_mask=inputs.attention_mask if hasattr(inputs, "attention_mask") else None
+        )
         caption = tokenizer.decode(output[0], skip_special_tokens=True)
         return caption
     except Exception as e:
@@ -154,9 +150,7 @@ def get_vit_gpt2_caption(image, model, feature_extractor, tokenizer):
 
 def get_git_caption(image, model, processor):
     try:
-        inputs = processor(images=image, return_tensors="pt")
-        if torch.cuda.is_available():
-            inputs = {k: v.to("cuda") for k, v in inputs.items()}
+        inputs = processor(images=image, return_tensors="pt", padding=True)
         output = model.generate(**inputs, **MODEL_CONFIGS["GIT"]["generate_params"])
         caption = processor.decode(output[0], skip_special_tokens=True)
         return caption
@@ -190,22 +184,16 @@ STYLE_ATTRIBUTES = [
 def get_clip_caption(image, model, processor):
     try:
         content_inputs = processor(text=CONTENT_CATEGORIES, images=image, return_tensors="pt", padding=True)
-        if torch.cuda.is_available():
-            content_inputs = {k: v.to("cuda") for k, v in content_inputs.items() if torch.is_tensor(v)}
         content_outputs = model(**content_inputs)
         content_probs = content_outputs.logits_per_image.softmax(dim=1)[0]
         top_content_probs, top_content_indices = torch.topk(content_probs, 2)
 
         scene_inputs = processor(text=SCENE_ATTRIBUTES, images=image, return_tensors="pt", padding=True)
-        if torch.cuda.is_available():
-            scene_inputs = {k: v.to("cuda") for k, v in scene_inputs.items() if torch.is_tensor(v)}
         scene_outputs = model(**scene_inputs)
         scene_probs = scene_outputs.logits_per_image.softmax(dim=1)[0]
         top_scene_probs, top_scene_indices = torch.topk(scene_probs, 2)
 
         style_inputs = processor(text=STYLE_ATTRIBUTES, images=image, return_tensors="pt", padding=True)
-        if torch.cuda.is_available():
-            style_inputs = {k: v.to("cuda") for k, v in style_inputs.items() if torch.is_tensor(v)}
         style_outputs = model(**style_inputs)
         style_probs = style_outputs.logits_per_image.softmax(dim=1)[0]
         top_style_probs, top_style_indices = torch.topk(style_probs, 1)
```
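Taken together, the commit switches to smaller checkpoints, drops the manual `.to("cuda")` placement (inference now runs on the default device), and adds `do_sample=True` to the BLIP generation parameters. The sketch below is a minimal, self-contained illustration of the updated BLIP path only, under those assumptions; the helper names `load_blip` and `caption_image` and the `example.jpg` path are placeholders for illustration and are not part of this commit.

```python
# Minimal sketch of the updated BLIP captioning path (illustrative, not the app itself).
# Assumes transformers, torch, and Pillow are installed; "example.jpg" is a placeholder path.
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# Mirrors the new MODEL_CONFIGS["BLIP"]["generate_params"], including do_sample=True.
BLIP_GENERATE_PARAMS = {
    "max_length": 50,
    "num_beams": 5,
    "min_length": 10,
    "do_sample": True,
    "top_p": 0.9,
    "repetition_penalty": 1.5,
}

def load_blip():
    # Base checkpoint, no explicit CUDA placement, matching the updated load_blip_model().
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return model, processor

def caption_image(image_path: str) -> str:
    model, processor = load_blip()
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    output = model.generate(**inputs, **BLIP_GENERATE_PARAMS)
    return processor.decode(output[0], skip_special_tokens=True)

if __name__ == "__main__":
    print(caption_image("example.jpg"))
```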