updating new version with supabase and vlm
- app.py +3 -1
- config/settings.py +5 -2
- logic/data_utils.py +36 -4
- logic/handlers.py +87 -18
- logic/supabase_client.py +193 -0
- logic/vlm.py +440 -0
- requirements.txt +67 -12
- ui/layout.py +308 -43
- ui/main_page.py +40 -0
- ui/selection_page.py +20 -2
app.py
CHANGED
@@ -1,6 +1,8 @@
 # import spacy.cli
 # spacy.cli.download("ja_core_news_sm")
 # spacy.cli.download("zh_core_web_sm")
+import os
+os.environ["TF_USE_LEGACY_KERAS"] = "1"
 import spacy_udpipe
 spacy_udpipe.download("ja")
 spacy_udpipe.download("zh")
@@ -16,7 +18,7 @@ metadata = load_metadata()
 
 demo = build_ui(concepts, metadata, HF_API_TOKEN, HF_DATASET_NAME)
 # demo.launch()
-demo.launch(debug=False)
+demo.launch(debug=False, server_port=7861)
 
 demo.close()
 # gr.close_all()
config/settings.py
CHANGED
@@ -3,6 +3,9 @@ import os
 
 load_dotenv()
 
-HF_API_TOKEN = os.getenv("
+HF_API_TOKEN = os.getenv("HF_TOKEN")
 HF_DATASET_NAME = os.getenv("HF_DATASET_NAME")
-LOCAL_DS_DIRECTORY_PATH = os.getenv("LOCAL_DS_DIRECTORY_PATH")
+LOCAL_DS_DIRECTORY_PATH = os.getenv("LOCAL_DS_DIRECTORY_PATH")
+SUPABASE_URL: str = os.getenv("SUPABASE_URL")
+SUPABASE_KEY: str = os.getenv("SUPABASE_KEY")
+REDIRECT_TO_URL: str = os.getenv("REDIRECT_TO_URL")
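For reference, a minimal sketch of the environment these settings expect. The variable names are the ones read above; the values are placeholders rather than real credentials, and in a Space they would normally be set as secrets or in a .env file picked up by load_dotenv().

# Illustrative placeholders only; do not commit real keys.
import os

os.environ.setdefault("HF_TOKEN", "hf_xxxxxxxxxxxxxxxx")               # Hugging Face API token
os.environ.setdefault("HF_DATASET_NAME", "your-org/your-dataset")      # target dataset repo
os.environ.setdefault("LOCAL_DS_DIRECTORY_PATH", "./local_dataset")    # local staging folder
os.environ.setdefault("SUPABASE_URL", "https://your-project.supabase.co")
os.environ.setdefault("SUPABASE_KEY", "your-supabase-anon-key")
os.environ.setdefault("REDIRECT_TO_URL", "https://your-space.hf.space")  # password-reset redirect target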
logic/data_utils.py
CHANGED
@@ -8,7 +8,7 @@ import uuid
 import gradio as gr
 from PIL import Image
 import numpy as np
-
+from logic.supabase_client import auth_handler
 
 def load_concepts(path="data/concepts.json"):
     with open(path, encoding='utf-8') as f:
@@ -53,6 +53,9 @@ class CustomHFDatasetSaver:
         self.local_ds_folder = local_ds_folder
         os.makedirs(self.local_ds_folder, exist_ok=True)
 
+        # Migrate any existing JSON files to include new VLM fields
+        self._migrate_existing()
+
         self.data_outputs = data_outputs # list of components to read values from
 
         # create scheduler to commit the data to the hub every x minutes
@@ -63,6 +66,27 @@ class CustomHFDatasetSaver:
             every=1,
             token=self.api_token,
         )
+
+    def _migrate_existing(self):
+        """
+        Ensure all existing JSON sample files have the same schema
+        by adding missing keys for 'vlm_caption' and 'vlm_feedback'.
+        """
+        for root, _, files in os.walk(self.local_ds_folder):
+            for fname in files:
+                if fname.endswith('.json'):
+                    fpath = os.path.join(root, fname)
+                    with open(fpath, 'r+', encoding='utf-8') as f:
+                        data = json.load(f)
+                        updated = False
+                        for key in ['vlm_caption', 'vlm_feedback']:
+                            if key not in data:
+                                data[key] = ""
+                                updated = True
+                        if updated:
+                            f.seek(0)
+                            json.dump(data, f, indent=2)
+                            f.truncate()
 
     def validate_data(self, values_dic):
         """
@@ -166,10 +190,16 @@ class CustomHFDatasetSaver:
         values_dic["id"] = f'{country}_{language}_{category}_{concept}_{current_timestamp}'
 
         #prepare the main directory of the sample
-        if
-
+        # here we check if the user is logged in or not
+        user_info = auth_handler.is_logged_in(values_dic.get("client", None))
+        print(f"User info: {user_info}")
+        if user_info['success']:
+            # sample_dir = os.path.join("logged_in_users", values_dic["country"], values_dic["language"], values_dic["username"], str(current_timestamp))
+            sample_dir = os.path.join("logged_in_users", values_dic["country"], values_dic["language"], user_info['email'], str(current_timestamp))
+            print(f"Sample directory for logged in user: {sample_dir}")
         else:
             sample_dir = os.path.join("anonymous_users", values_dic["country"], values_dic["language"], str(uuid.uuid4()), str(current_timestamp))
+        print(f"Sample directory: {sample_dir}")
 
         os.makedirs(os.path.join(self.local_ds_folder, sample_dir), exist_ok=True)
 
@@ -217,6 +247,8 @@ class CustomHFDatasetSaver:
             # "image_file": image_file_path_on_hub,
             "image_url": values_dic['image_url'] or "",
             "caption": values_dic['caption'] or "",
+            "vlm_caption": values_dic['vlm_caption'] or "",
+            "vlm_feedback": values_dic['vlm_feedback'] or "",
             "country": values_dic['country'] or "",
             "language": values_dic['language'] or "",
             "category": values_dic['category'] or "",
@@ -227,7 +259,7 @@ class CustomHFDatasetSaver:
             "category_4_concepts": values_dic.get('category_4_concepts') or [""],
             "category_5_concepts": values_dic.get('category_5_concepts') or [""],
             "timestamp": current_timestamp,
-            "username":
+            "username": user_info['email'] if user_info['success'] else "",
             "password": values_dic['password'] or "",
             "id": values_dic['id'],
             "excluded": False if values_dic.get('excluded') is None else bool(values_dic.get('excluded')),
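The net effect of the data_utils changes is that every saved sample JSON now carries two extra VLM fields and, for authenticated users, the Supabase account email as the username. A hedged sketch of one record after this change (all values are placeholders; fields not visible in the diff are omitted):

# Placeholder record; only fields visible in the diff are listed.
sample_record = {
    "image_url": "https://example.com/photo.jpg",
    "caption": "Contributor-written caption",
    "vlm_caption": "Model-generated ~40-word description",      # new field
    "vlm_feedback": "Contributor feedback on the VLM caption",  # new field
    "country": "Japan",
    "language": "Japanese",
    "category": "food",
    "concept": "ramen",
    "timestamp": "1735689600",
    "username": "user@example.com",  # Supabase email for logged-in users, "" otherwise
    "password": "",
    "id": "Japan_Japanese_food_ramen_1735689600",
    "excluded": False,
}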
logic/handlers.py
CHANGED
@@ -4,6 +4,7 @@ import io
 import PIL
 import requests
 from typing import Literal
+from logic.supabase_client import auth_handler
 
 from datasets import load_dataset, concatenate_datasets, Image
 from data.lang2eng_map import lang2eng_mapping
@@ -12,6 +13,7 @@ import gradio as gr
 import bcrypt
 from config.settings import HF_API_TOKEN
 from huggingface_hub import snapshot_download
+from logic.vlm import vlm_manager
 # from .blur import blur_faces, detect_faces
 from retinaface import RetinaFace
 from gradio_modal import Modal
@@ -71,13 +73,13 @@ def clear_data(message: Literal["submit", "remove"] | None = None):
         gr.Info("If you logged in, you will soon see it at the bottom of the page, where you can edit it or delete it", title="Thank you for submitting your data! 🙏", duration=5)
     elif message == "remove":
         gr.Info("", title="Your data has been deleted! 🗑️", duration=5)
-    return (None, None, None, None, None, gr.update(value=None),
+    return (None, None, None, gr.update(value=None), gr.update(value=None, visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True), None, None, gr.update(value=None),
            gr.update(value=[]), gr.update(value=[]), gr.update(value=[]),
            gr.update(value=[]), gr.update(value=[]))
 
 
 def exit():
-    return (None, None, None, gr.Dataset(samples=[]), gr.Markdown("**Loading your data, please wait ...**"),
+    return (None, None, None, gr.update(value=None), gr.update(value=None, visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True), gr.Dataset(samples=[]), gr.Markdown("**Loading your data, please wait ...**"),
            gr.update(value=None), gr.update(value=None), [None, None, "", ""], gr.update(value=None),
            gr.update(value=None), gr.update(value=None),
            gr.update(value=None), gr.update(value=None), gr.update(value=None),
@@ -87,9 +89,8 @@ def exit():
 def validate_metadata(country, language):
     # Perform your validation logic here
     if country is None or language is None:
-        return gr.
-
-    return gr.Button("Proceed", interactive=True)
+        return gr.update(interactive=False)
+    return gr.update(interactive=True)
 
 
 def validate_inputs(image, ori_img, concept): # is_blurred
@@ -129,6 +130,30 @@ def validate_inputs(image, ori_img, concept): # is_blurred
 
     return gr.Button("Submit", variant="primary", interactive=True), result_image, ori_img # is_blurred
 
+def generate_vlm_caption(image, model_name="SmolVLM-500M"): # processor, model
+    """
+    Generate a caption for the given image using a Vision-Language Model.
+    Uses the global VLMManager for efficient model loading and caching.
+    """
+    if image is None:
+        gr.Warning("⚠️ Please upload an image first.", duration=5)
+        return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
+
+    try:
+        # Use the global VLMManager to load/get the model
+        vlm_manager.load_model(model_name)
+        caption = vlm_manager.generate_caption(image)
+    except Exception as e:
+        print(f"Error generating caption: {e}. Cleaning up memory and try again.")
+        gr.Warning(f"⚠️ Error generating caption: {e} due to memory issues. Please try again.", duration=5)
+        # vlm_manager.cleanup_memory()
+        return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
+    finally: # For now, let's cleanup memory after each generation
+        vlm_manager.cleanup_memory()
+
+    # print(caption)
+
+    return caption, gr.update(visible=True), gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)
 
 def count_words(caption, language):
     match language:
@@ -152,8 +177,14 @@ def add_prefix(example, column_name, prefix):
     example[column_name] = (f"{prefix}/" + example[column_name])
     return example
 
-def update_user_data(
+def update_user_data(client, country, language_choice, HF_DATASET_NAME, local_ds_directory_path):
+    user_info = auth_handler.is_logged_in(client)
+    print(f"User info: {user_info}")
+    if not user_info['success']:
+        print("User is not logged in or session expired.")
+        return gr.Dataset(samples=[]), None, None
 
+    username = user_info['email']
     datasets_list = []
     # Try loading local dataset
     try:
@@ -191,18 +222,19 @@ def update_user_data(username, password, country, language_choice, HF_DATASET_NA
     # Handle all empty
     if not datasets_list:
         if username: # User is logged in but has no data
-            return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>")
+            return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>"), None
         else: # No user logged in
-            return gr.Dataset(samples=[]), gr.Markdown("")
+            return gr.Dataset(samples=[]), gr.Markdown(""), None
 
     dataset = concatenate_datasets(datasets_list)
     # TODO: we should link username with password and language and country, otherwise there will be an error when loading with different language and clicking on the example
-    if username
-        user_dataset = dataset.filter(lambda x: x['username'] == username
+    if username:
+        user_dataset = dataset.filter(lambda x: x['username'] == username)
         user_dataset = user_dataset.sort('timestamp', reverse=True)
         # Show only unique entries (most recent)
         user_ids = set()
         samples = []
+        vlm_captions = dict()
        for d in user_dataset:
            if d['id'] in user_ids:
                continue
@@ -229,6 +261,10 @@ def update_user_data(username, password, country, language_choice, HF_DATASET_NA
                d['image_file'], d['image_url'], d['caption'] or "", d['country'],
                d['language'], d['category'], d['concept'], additional_concepts_by_category, d['id']] # d['is_blurred']
            )
+
+            if 'vlm_caption' in d:
+                vlm_captions[d['id']] = d.get('vlm_caption', "")
+
        # return gr.Dataset(samples=samples), None
        # ───────────────────────────────────────────────────
        # Clean up the “Additional Concepts” column (index 7)
@@ -255,10 +291,14 @@ def update_user_data(username, password, country, language_choice, HF_DATASET_NA
                row_copy[7] = ", ".join(vals)
            cleaned.append(row_copy)
 
-
+        # check if vlm_captions is an empty dictionary
+        if not vlm_captions:
+            vlm_captions = None
+
+        return gr.Dataset(samples=cleaned), None, vlm_captions
    else:
        # TODO: should we show the entire dataset instead? What about "other data" tab?
-        return gr.Dataset(samples=[]), None
+        return gr.Dataset(samples=[]), None, None
 
 
 def update_language(local_storage, metadata_dict, concepts_dict):
@@ -357,7 +397,7 @@ def update_intro_language(selected_country, selected_language, intro_markdown, m
     return gr.Markdown(INTRO_TEXT)
 
 
-def handle_click_example(user_examples, concepts_dict):
+def handle_click_example(user_examples, vlm_captions, concepts_dict):
     # print("handle_click_example")
     # print(user_examples)
     # ex = [item for item in user_examples]
@@ -365,7 +405,6 @@ def handle_click_example(user_examples, concepts_dict):
     # 1) Turn the flat string in slot 7 back into a list-of-lists
     ex = list(user_examples)
     raw_ac = ex[7] if len(ex) > 7 else ""
-
     country_btn = ex[3]
     language_btn = ex[4]
     concepts = concepts_dict[country_btn][language_btn]
@@ -441,7 +480,13 @@ def handle_click_example(user_examples, concepts_dict):
     # dropdown_values.append(None)
 
     # Need to return values for each category dropdown
-
+
+    vlm_caption = None
+    if vlm_captions:
+        if exampleid_btn in vlm_captions:
+            vlm_caption = vlm_captions[exampleid_btn]
+
+    return [image_inp, image_url_inp, long_caption_inp, exampleid_btn, category_btn, concept_btn] + additional_concepts_by_category + [True] + [vlm_caption] # loading_example flag + vlm_caption
     # return [
     #     image_inp,
     #     image_url_inp,
@@ -535,8 +580,8 @@ def blur_selected_faces(image, blur_faces_ids, faces_info, face_img, faces_count
     parsed_faces_ids = [f"face_{val.split(':')[-1].strip()}" for val in parsed_faces_ids]
 
     # Base blur amount and bounds
-    MIN_BLUR =
-    MAX_BLUR =
+    MIN_BLUR = 131 # Minimum blur amount (must be odd)
+    MAX_BLUR = 351 # Maximum blur amount (must be odd)
 
     blurring_start = time.time()
     # Process each face
@@ -688,4 +733,28 @@ def check_exclude_fn(image):
 
 def has_user_json(username, country, language_choice, local_ds_directory_path):
     """Check if JSON files exist for username pattern."""
-    return bool(glob.glob(os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username, "**", "*.json"), recursive=True))
+    return bool(glob.glob(os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username, "**", "*.json"), recursive=True))
+
+def submit_button_clicked(vlm_output):
+
+    if vlm_output is None or vlm_output == '':
+        return Modal(visible=True), Modal(visible=False)
+    else:
+        return Modal(visible=False), Modal(visible=True)
+# def submit_button_clicked(vlm_output, save_fn, data_outputs):
+#     if vlm_output is None:
+#         return Modal(visible=True)
+#     else:
+#         try:
+#             save_fn(list(data_outputs.values()))
+#         except Exception as e:
+#             gr.Error(f"⚠️ Error saving data: {e}")
+
+#         try:
+#             image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, exampleid_btn, category_btn, concept_btn, \
+#                 category_concept_dropdowns0, category_concept_dropdowns1, category_concept_dropdowns2, category_concept_dropdowns3, \
+#                 category_concept_dropdowns4 = clear_data("submit")
+#         except Exception as e:
+#             gr.Error(f"⚠️ Error clearing data: {e}")
+
+#         return Modal(visible=False)
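generate_vlm_caption() and submit_button_clicked() are plain Gradio callbacks; the real wiring lives in ui/layout.py, whose added lines are not reproduced in this excerpt. The sketch below only illustrates how a five-output callback like generate_vlm_caption() could be attached to a button; every component name here is made up.

# Hypothetical wiring sketch; component names are illustrative, not the app's real ones.
import gradio as gr
from logic.handlers import generate_vlm_caption

with gr.Blocks() as demo:
    image_inp = gr.Image(label="Image")  # numpy array, as the VLM inference code expects
    generate_btn = gr.Button("Generate VLM caption")
    vlm_caption_box = gr.Textbox(label="VLM caption", interactive=False)
    vlm_feedback_box = gr.Textbox(label="Feedback on the VLM caption", visible=False)
    submit_btn = gr.Button("Submit", visible=False)

    # The handler returns (caption, visibility update, visibility update,
    # interactivity update, interactivity update), hence five outputs.
    generate_btn.click(
        fn=generate_vlm_caption,
        inputs=[image_inp],
        outputs=[vlm_caption_box, vlm_feedback_box, submit_btn, image_inp, generate_btn],
    )

demo.launch()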
logic/supabase_client.py
ADDED
@@ -0,0 +1,193 @@
+import gradio as gr
+from supabase import create_client, Client
+import os
+from config.settings import SUPABASE_URL, SUPABASE_KEY, REDIRECT_TO_URL
+import traceback
+from supabase.lib.client_options import ClientOptions
+
+
+# --- Supabase Authentication Class ---
+
+class SupabaseAuth:
+    """A class to handle Supabase authentication logic."""
+    def __init__(self, url: str, key: str):
+        self.url = url
+        self.key = key
+        try:
+            self.client: Client = create_client(url, key)
+        except Exception as e:
+            print(f"Error creating Supabase client: {e}")
+            self.client = None
+
+    def login(self, email: str, password: str):
+        """
+        Attempts to log in a user and returns a user-specific client.
+        """
+        if not self.client:
+            return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
+        try:
+            response = self.client.auth.sign_in_with_password({"email": email, "password": password})
+            user_session = response.session
+
+            # Create a new, authenticated client for this user
+            authenticated_client = create_client(
+                self.url,
+                self.key,
+                # options={"headers": {"Authorization": f"Bearer {user_session.access_token}"}}
+                options=ClientOptions(
+                    headers={"Authorization": f"Bearer {user_session.access_token}"},
+                )
+            )
+            authenticated_client.auth.set_session(user_session.access_token, user_session.refresh_token)
+
+            session_data = {
+                "refresh_token": user_session.refresh_token,
+                "user_email": user_session.user.email,
+                "client": authenticated_client
+            }
+            return {'success': True, 'data': session_data, 'message': f"Welcome, {user_session.user.email}!"}
+        except Exception as e:
+            # print(f"Error logging in: {e}")
+            # traceback.print_exc()
+            # Handle specific error messages for better user feedback
+            return {'success': False, 'data': None, 'message': f"Login failed: {e}"}
+
+    def sign_up(self, email: str, password: str):
+        """Signs up a new user."""
+        if not self.client:
+            return {'success': False, 'message': "Supabase client not initialized."}
+        try:
+            # Supabase sign_up returns a session if email confirmation is disabled,
+            # or just a user object if it's enabled. We'll just return a success message.
+            self.client.auth.sign_up({
+                "email": email,
+                "password": password,
+            })
+            return {'success': True, 'message': 'Sign up successful! You can login now.'}
+        except Exception as e:
+            return {'success': False, 'message': f"Sign up failed: {e}"}
+
+    def restore_session(self, refresh_token: str):
+        """
+        Attempts to restore a session using a refresh token.
+        """
+        if not self.client:
+            return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
+        try:
+            response = self.client.auth.refresh_session(refresh_token)
+            user_session = response.session
+
+            authenticated_client = create_client(
+                self.url,
+                self.key,
+                options=ClientOptions(
+                    headers={"Authorization": f"Bearer {user_session.access_token}"},
+                )
+            )
+            authenticated_client.auth.set_session(user_session.access_token, user_session.refresh_token)
+
+            session_data = {
+                "refresh_token": user_session.refresh_token,
+                "user_email": user_session.user.email,
+                "client": authenticated_client
+            }
+            print("Session restored successfully:", session_data)
+            return {'success': True, 'data': session_data, 'message': f"Welcome, {user_session.user.email}!"}
+        except Exception as e:
+            print("failed to restore session:", e)
+            return {'success': False, 'data': None, 'message': f"Failed to restore session: {e}"}
+
+    def logout(self, user_client: Client):
+        """Signs out the user from Supabase, invalidating the token."""
+        if not user_client:
+            return {'success': False, 'message': 'No user client provided to log out.'}
+        try:
+            user_client.auth.sign_out()
+            return {'success': True, 'message': 'Successfully signed out from Supabase.'}
+        except Exception as e:
+            # It's often safe to ignore errors here (e.g., if token already expired)
+            # but we'll log it for debugging.
+            print(f"Error signing out from Supabase: {e}")
+            return {'success': False, 'message': f'Error signing out: {e}'}
+
+    def change_password(self, user_client: Client, new_password: str):
+        """Changes the user's password."""
+        if not user_client:
+            return {'success': False, 'message': 'No user client provided to change password.'}
+        try:
+            user_client.auth.update_user({"password": new_password})
+            return {'success': True, 'message': 'Password changed successfully.'}
+        except Exception as e:
+            return {'success': False, 'message': f'Error changing password: {e}'}
+
+    def is_logged_in(self, user_client: Client):
+        """Checks if a user is currently authenticated and returns their email."""
+        print("Checking if user is logged in...", user_client)
+        if not user_client:
+            return {'success': False, 'email': None, 'message': 'No user client provided.'}
+        try:
+            user_response = user_client.auth.get_user()
+            user = user_response.user
+            if user:
+                return {'success': True, 'email': user.email, 'message': f'Logged in as: {user.email}'}
+            else:
+                return {'success': False, 'email': None, 'message': 'User is not logged in.'}
+        except Exception as e:
+            # This might happen if the token has expired and can't be refreshed.
+            return {'success': False, 'email': None, 'message': f'Authentication check failed: {e}'}
+
+    def reset_password_for_email(self, email: str):
+        """
+        Sends a password reset email to the specified address.
+        """
+        if not self.client:
+            return {'success': False, 'message': "Supabase client not initialized."}
+        try:
+            self.client.auth.reset_password_for_email(
+                email,
+                {
+                    "redirect_to": str(REDIRECT_TO_URL),
+                }
+            )
+            return {'success': True, 'message': "Password reset email sent. Check your inbox!"}
+        except Exception as e:
+            return {'success': False, 'message': f"Failed to send reset email: {e}"}
+
+    def retrieve_session_from_tokens(self, access_token: str, refresh_token: str):
+        """
+        Retrieves a session from an access token and refresh token.
+        This is typically used after a password recovery link is clicked.
+        """
+        if not self.client:
+            return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
+        try:
+            # Set the session on the main client to verify tokens and get user info
+            self.client.auth.set_session(access_token, refresh_token)
+            user_response = self.client.auth.get_user()
+            user = user_response.user
+
+            if not user:
+                return {'success': False, 'data': None, 'message': "Could not retrieve user from tokens."}
+
+            # Create a new, authenticated client for this user, similar to login
+            authenticated_client = create_client(
+                self.url,
+                self.key,
+                options=ClientOptions(
+                    headers={"Authorization": f"Bearer {access_token}"},
+                )
+            )
+            authenticated_client.auth.set_session(access_token, refresh_token)
+
+            session_data = {
+                "refresh_token": refresh_token,
+                "user_email": user.email,
+                "client": authenticated_client
+            }
+            return {'success': True, 'data': session_data, 'message': f"Welcome, {user.email}!"}
+        except Exception as e:
+            return {'success': False, 'data': None, 'message': f"Failed to retrieve session from tokens: {e}"}
+
+
+auth_handler = SupabaseAuth(SUPABASE_URL, SUPABASE_KEY)
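auth_handler is the module-level singleton the rest of the app imports. A minimal usage sketch (email and password are placeholders) showing the dict-shaped results and the per-user client that handlers.py and data_utils.py pass around:

from logic.supabase_client import auth_handler

# Placeholder credentials.
result = auth_handler.login("user@example.com", "a-strong-password")
if result['success']:
    user_client = result['data']['client']
    refresh_token = result['data']['refresh_token']

    # The per-user client is what is_logged_in() expects (data_utils reads it from values_dic["client"]).
    status = auth_handler.is_logged_in(user_client)
    print(status['message'])  # e.g. "Logged in as: user@example.com"

    # A later session can be restored from the stored refresh token.
    restored = auth_handler.restore_session(refresh_token)

    auth_handler.logout(user_client)
else:
    print(result['message'])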
logic/vlm.py
ADDED
@@ -0,0 +1,440 @@
+import torch, torchvision.transforms as T
+from torchvision.transforms.functional import InterpolationMode
+from PIL import Image
+from transformers import TorchAoConfig, Qwen2_5_VLForConditionalGeneration, Gemma3ForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, AutoModel
+from qwen_vl_utils import process_vision_info
+import gc
+# from transformers.image_utils import load_image
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+class VLMManager:
+    """
+    A manager class for Vision-Language Models that handles model loading,
+    caching, and dynamic switching between different models.
+    """
+
+    def __init__(self, default_model: str = "Gemma3-4B"):
+        """
+        Initialize the VLM Manager with a default model.
+
+        Args:
+            default_model (str): The default model to load initially.
+        """
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.current_model_name = None
+        self.processor = None
+        self.tokenizer = None # Initialize tokenizer attribute
+        self.model = None
+
+        self.system_message = """
+        You are an expert cultural-aware image-analysis assistant. For every image:
+        1. Output exactly 40 words in total.
+        2. Use a single paragraph (no lists or bullet points).
+        3. Describe Who (appearance/emotion), What (action), and Where (setting).
+        4. Do NOT include opinions or speculations.
+        5. If you go over 40 words, shorten or remove non-essential details.
+        """
+
+        self.user_prompt = """
+        Given this image, please provide an image description of around 40 words with extensive and detailed visual information.
+
+        Descriptions must be objective: focus on how you would describe the image to someone who can't see it, without your own opinions/speculations.
+
+        The text needs to include the main concept and describe the content of the image in detail by including:
+        - Who?: The visual appearance and observable emotions (e.g., "is smiling") of persons and animals.
+        - What?: The actions performed in the image.
+        - Where?: The setting of the image, including the size, color, and relationships between objects.
+        """
+
+        # Load the default model
+        self.load_model(default_model)
+
+    def load_model(self, model_name: str):
+        """
+        Load a VLM model. If the model is already loaded, return the cached version.
+
+        Args:
+            model_name (str): The name of the model to load.
+        """
+        # If the requested model is already loaded, no need to reload
+        if self.current_model_name == model_name and self.model is not None:
+            print(f"Model {model_name} is already loaded, using cached version.")
+            if self.current_model_name == "InternVL3_5-8B":
+                return self.tokenizer, self.model
+            else:
+                return self.processor, self.model
+
+        print(f"Loading model: {model_name}")
+
+        # Clear current model from memory if exists
+        if self.model is not None:
+            del self.model
+            self.model = None
+        if self.current_model_name == "InternVL3_5-8B":
+            if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+                del self.tokenizer
+                self.tokenizer = None
+        else:
+            if hasattr(self, 'processor') and self.processor is not None:
+                del self.processor
+                self.processor = None
+        # Force garbage collection and clear CUDA cache
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize() # Wait for all operations to complete
+
+        # Load the new model
+        if model_name == "SmolVLM-500M":
+            self.processor, self.model = self._load_smolvlm_model("HuggingFaceTB/SmolVLM-500M-Instruct")
+        elif model_name == "Qwen2.5-VL-7B":
+            self.processor, self.model = self._load_qwen25_model("Qwen/Qwen2.5-VL-7B-Instruct")
+        elif model_name == "InternVL3_5-8B":
+            self.tokenizer, self.model = self._load_internvl35_model("OpenGVLab/InternVL3_5-8B-Instruct")
+        elif model_name == "Gemma3-4B":
+            self.processor, self.model = self._load_gemma3_model("google/gemma-3-4b-it")
+        else:
+            raise ValueError(f"Model {model_name} is not supported or not available.")
+
+        self.current_model_name = model_name
+        print(f"Successfully loaded model: {model_name}")
+
+    def generate_caption(self, image):
+        """
+        Generate a caption for the given image using the loaded model.
+
+        Args:
+            processor: The processor for the model.
+            model: The model to use for generating the caption.
+            image: The image to generate a caption for.
+        """
+        if self.current_model_name == "SmolVLM-500M":
+            return self._inference_smolvlm_model(image)
+        elif self.current_model_name == "Qwen2.5-VL-7B":
+            return self._inference_qwen25_model(image)
+        elif self.current_model_name == "InternVL3_5-8B":
+            return self._inference_internvl35_model(image)
+        elif self.current_model_name == "Gemma3-4B":
+            return self._inference_gemma3_model(image)
+        else:
+            raise ValueError(f"Model {self.current_model_name} is not supported or not available.")
+
+    def get_current_model(self):
+        """
+        Get the currently loaded model and processor.
+
+        Returns:
+            tuple: A tuple containing (processor, model, model_name).
+        """
+        return self.processor, self.model, self.current_model_name
+
+    def cleanup_memory(self):
+        """
+        Explicit memory cleanup method that can be called to free GPU memory.
+        """
+        if self.model is not None:
+            del self.model
+            self.model = None
+        if hasattr(self, 'processor') and self.processor is not None:
+            del self.processor
+            self.processor = None
+        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+            del self.tokenizer
+            self.tokenizer = None
+
+        self.current_model_name = None
+
+        # Force cleanup
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+        print("Memory cleanup completed.")
+
+    #########################################################
+    ## Load functions
+
+    def _load_smolvlm_model(self, model_name):
+        """Load SmolVLM model."""
+        processor = AutoProcessor.from_pretrained(model_name)
+        model = AutoModelForVision2Seq.from_pretrained(
+            model_name,
+            _attn_implementation="eager"
+        ).to(self.device)
+        model.eval()
+        return processor, model
+
+    def _load_qwen25_model(self, model_name):
+        """Load Qwen2.5-VL model."""
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_name, torch_dtype="auto", device_map="auto"
+        )
+
+        # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+        # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        #     "Qwen/Qwen2.5-VL-7B-Instruct",
+        #     torch_dtype=torch.bfloat16,
+        #     attn_implementation="flash_attention_2",
+        #     device_map="auto",
+        # )
+
+        processor = AutoProcessor.from_pretrained(model_name)
+        model.eval()
+        return processor, model
+
+    def _load_internvl35_model(self, model_name):
+        """Load InternVL3.5 model."""
+        # Load tokenizer (InternVL uses tokenizer instead of processor for text)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+        # Load the model using AutoModel
+        model = AutoModel.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
+            low_cpu_mem_usage=True,
+            use_flash_attn=False, # True set False if CUDA mismatch
+            trust_remote_code=True,
+            device_map="auto"
+        )
+
+        model.eval()
+
+        # Return tokenizer as processor for consistency with the interface
+        return tokenizer, model
+
+    def _load_gemma3_model(self, model_name):
+        """Load Gemma3 model."""
+        quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            model_name,
+            device_map="auto",
+            quantization_config=quantization_config
+        )
+        processor = AutoProcessor.from_pretrained(model_name)
+        model.eval()
+        return processor, model
+
+    #########################################################
+    ## Inference functions
+    def check_processor_and_model(self):
+        if self.processor is None or self.model is None:
+            raise ValueError("Processor and model must be loaded before generating a caption.")
+
+    def _inference_qwen25_model(self, image):
+        """Inference Qwen2.5-VL model."""
+        self.check_processor_and_model()
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": self.system_message}]
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": Image.fromarray(image),
+                    },
+                    {"type": "text", "text": self.user_prompt},
+                ],
+            }
+        ]
+
+        # Preparation for inference
+        text = self.processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = self.processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(self.model.device)
+
+        # Inference: Generation of the output
+        generated_ids = self.model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        caption = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+
+        # Clean up tensors to free GPU memory
+        del inputs, generated_ids, generated_ids_trimmed
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return caption
+
+    def _inference_gemma3_model(self, image):
+        """Inference Gemma3 model."""
+        self.check_processor_and_model()
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": self.system_message}]
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": Image.fromarray(image)},
+                    {"type": "text", "text": self.user_prompt}
+                ]
+            }
+        ]
+
+        inputs = self.processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(self.model.device, dtype=torch.bfloat16)
+
+        input_len = inputs["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
+            generation = generation[0][input_len:]
+
+        caption = self.processor.decode(generation, skip_special_tokens=True)
+
+        # Clean up tensors to free GPU memory
+        del inputs, generation
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return caption
+
+    def _inference_smolvlm_model(self, image):
+        self.check_processor_and_model()
+        messages = [
+            {
+                "role": "system",
+                "content": self.system_message
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": self.user_prompt}
+                ]
+            }
+        ]
+
+        # Prepare inputs
+        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+        inputs = inputs.to(self.model.device)
+
+        # Generate outputs
+        gen_kwargs = {
+            "max_new_tokens": 200, # plenty for ~40 words
+            # "early_stopping": True, # stop at first EOS
+            # "no_repeat_ngram_size": 3, # discourage loops
+            # "length_penalty": 0.8, # slightly favor brevity
+            # "eos_token_id": processor.tokenizer.eos_token_id,
+            # "pad_token_id": processor.tokenizer.eos_token_id,
+        }
+        generated_ids = self.model.generate(**inputs, **gen_kwargs) # max_new_tokens=500)
+        generated_texts = self.processor.batch_decode(
+            generated_ids,
+            skip_special_tokens=True,
+        )[0]
+
+        # Extract only what the assistant said
+        if "Assistant:" in generated_texts:
+            caption = generated_texts.split("Assistant:", 1)[1].strip()
+        else:
+            caption = generated_texts.strip()
+
+        # Clean up tensors to free GPU memory
+        del inputs, generated_ids
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return caption
+
+    def _inference_internvl35_model(self, image):
+        if self.tokenizer is None:
+            raise ValueError("Tokenizer must be loaded before generating a caption for InternVL3.5.")
+        # image can be numpy (H,W,3) or PIL.Image
+        if hasattr(image, "shape"): # numpy array
+            pil_image = Image.fromarray(image.astype("uint8"), mode="RGB")
+        else:
+            pil_image = image
+
+        pixel_values = self._image_to_pixel_values(pil_image, size=448, max_num=12)
+        pixel_values = pixel_values.to(dtype=torch.bfloat16, device=self.model.device)
+
+        # Format question with image token (matches official docs)
+        question = "<image>\n" + self.user_prompt
+
+        # Generation config matching official examples
+        gen_cfg = dict(
+            max_new_tokens=128,
+            do_sample=False,
+            temperature=0.0,
+            # Optional: add other parameters from docs
+            # top_p=0.9,
+            # repetition_penalty=1.1
+        )
+
+        # Use model's chat method (official approach)
+        response = self.model.chat(self.tokenizer, pixel_values, question, gen_cfg)
+
+        # Clean up tensors to free GPU memory
+        del pixel_values
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return response.strip()
+
+    def _image_to_pixel_values(self, img, size=448, max_num=12):
+        transform = self._build_transform(size)
+        tiles = self._dynamic_preprocess(img, image_size=size, max_num=max_num, use_thumbnail=True)
+        pixel_values = torch.stack([transform(t) for t in tiles])
+        return pixel_values
+
+    def _dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
+        # same logic as the model card: split into tiles based on aspect ratio
+        w, h = image.size
+        aspect = w / h
+        targets = sorted({(i, j) for n in range(min_num, max_num+1)
+                          for i in range(1, n+1) for j in range(1, n+1)
+                          if i*j <= max_num and i*j >= min_num},
+                         key=lambda x: x[0]*x[1])
+
+        # pick closest ratio
+        best = min(targets, key=lambda r: abs(aspect - r[0]/r[1]))
+        tw, th = image_size * best[0], image_size * best[1]
+        resized = image.resize((tw, th))
+
+        tiles = []
+        for i in range(best[0] * best[1]):
+            box = ((i % (tw // image_size)) * image_size,
+                   (i // (tw // image_size)) * image_size,
+                   ((i % (tw // image_size)) + 1) * image_size,
+                   ((i // (tw // image_size)) + 1) * image_size)
+            tiles.append(resized.crop(box))
+
+        if use_thumbnail and len(tiles) != 1:
+            tiles.append(image.resize((image_size, image_size)))
+        return tiles
+
+    def _build_transform(self, input_size=448):
+        return T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+        ])
+
+
+# Global VLM Manager instance
+vlm_manager = VLMManager()
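A minimal caption-generation sketch against the VLMManager defined above. The image path is a placeholder; importing logic.vlm already loads the default Gemma3-4B model, and each load/generate cycle needs the corresponding model weights and GPU memory.

import numpy as np
from PIL import Image
from logic.vlm import vlm_manager  # module import instantiates the manager with its default model

# Placeholder image path; the inference methods expect an RGB numpy array.
image = np.array(Image.open("example.jpg").convert("RGB"))

# Switch to the lighter model (the manager frees the previous one), caption, then clean up,
# mirroring what generate_vlm_caption() in logic/handlers.py does.
vlm_manager.load_model("SmolVLM-500M")
caption = vlm_manager.generate_caption(image)
print(caption)
vlm_manager.cleanup_memory()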
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
absl-py==2.2.2
|
|
|
2 |
aiofiles==23.2.1
|
3 |
aiohappyeyeballs==2.6.1
|
4 |
aiohttp==3.11.16
|
@@ -8,15 +9,25 @@ anyio==4.9.0
|
|
8 |
astunparse==1.6.3
|
9 |
async-timeout==5.0.1
|
10 |
attrs==25.3.0
|
|
|
11 |
bcrypt==4.3.0
|
12 |
beautifulsoup4==4.13.3
|
|
|
|
|
|
|
13 |
certifi==2025.1.31
|
14 |
charset-normalizer==3.4.1
|
15 |
click==8.1.8
|
|
|
|
|
16 |
cycler==0.12.1
|
|
|
17 |
datasets==3.5.0
|
|
|
18 |
deep-translator==1.11.4
|
|
|
19 |
dill==0.3.8
|
|
|
20 |
et_xmlfile==2.0.0
|
21 |
exceptiongroup==1.2.2
|
22 |
fastapi==0.115.12
|
@@ -41,16 +52,36 @@ huggingface-hub==0.30.1
|
|
41 |
idna==3.10
|
42 |
Jinja2==3.1.6
|
43 |
keras==3.9.2
|
|
|
|
|
44 |
libclang==18.1.1
|
|
|
45 |
Markdown==3.8
|
46 |
markdown-it-py==3.0.0
|
47 |
MarkupSafe==3.0.2
|
48 |
mdurl==0.1.2
|
49 |
ml_dtypes==0.5.1
|
|
|
50 |
multidict==6.3.2
|
51 |
multiprocess==0.70.16
|
|
|
52 |
namex==0.0.8
|
|
|
53 |
numpy==2.1.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
opencv-python==4.11.0.86
|
55 |
openpyxl==3.1.5
|
56 |
opt_einsum==3.4.0
|
@@ -59,54 +90,78 @@ orjson==3.10.16
|
|
59 |
packaging==24.2
|
60 |
pandas==2.2.3
|
61 |
pillow==11.1.0
|
|
|
|
|
62 |
propcache==0.3.1
|
63 |
protobuf==5.29.4
|
|
|
64 |
pyarrow==19.0.1
|
65 |
-
pydantic
|
66 |
-
pydantic_core
|
67 |
pydub==0.25.1
|
68 |
Pygments==2.19.1
|
69 |
PySocks==1.7.1
|
|
|
70 |
python-dateutil==2.9.0.post0
|
71 |
python-dotenv==1.1.0
|
72 |
python-multipart==0.0.20
|
73 |
pytz==2025.2
|
|
|
74 |
PyYAML==6.0.2
|
|
|
|
|
75 |
requests==2.32.3
|
76 |
retina-face==0.0.17
|
77 |
rich==14.0.0
|
78 |
ruff==0.11.4
|
79 |
safehttpx==0.1.6
|
|
|
80 |
semantic-version==2.10.0
81 | shellingham==1.5.4
82 | six==1.17.0
83 | sniffio==1.3.1
84 | soupsieve==2.6
85 | starlette==0.46.1
86 | tensorboard==2.19.0
87 | tensorboard-data-server==0.7.2
88 | tensorflow==2.19.0
89 | tensorflow-io-gcs-filesystem==0.37.1
90 | termcolor==3.0.1
91 | tf_keras==2.19.0
92 | tomlkit==0.13.2
93 | tqdm==4.67.1
94 | typer==0.15.2
95 | - typing-inspection
96 | - typing_extensions
97 | tzdata==2025.2
98 | urllib3==2.3.0
99 | uvicorn==0.34.0
100 | websockets==15.0.1
101 | Werkzeug==3.1.3
102 | wrapt==1.17.2
103 | xxhash==3.5.0
104 | yarl==1.19.0
105 | -
106 | -
107 | -
108 | -
109 | - spacy-legacy==3.0.12
110 | - spacy-loggers==1.0.5
111 | - spacy_thai==0.7.8
112 | - spacy-udpipe==1.0.0

1 | absl-py==2.2.2
2 | + accelerate==1.9.0
3 | aiofiles==23.2.1
4 | aiohappyeyeballs==2.6.1
5 | aiohttp==3.11.16
9 | astunparse==1.6.3
10 | async-timeout==5.0.1
11 | attrs==25.3.0
12 | + av==15.1.0
13 | bcrypt==4.3.0
14 | beautifulsoup4==4.13.3
15 | + bitsandbytes==0.46.1
16 | + blis==1.3.0
17 | + catalogue==2.0.10
18 | certifi==2025.1.31
19 | charset-normalizer==3.4.1
20 | click==8.1.8
21 | + cloudpathlib==0.21.1
22 | + confection==0.1.5
23 | cycler==0.12.1
24 | + cymem==2.0.11
25 | datasets==3.5.0
26 | + decord==0.6.0
27 | deep-translator==1.11.4
28 | + deplacy==2.1.0
29 | dill==0.3.8
30 | + einops==0.8.1
31 | et_xmlfile==2.0.0
32 | exceptiongroup==1.2.2
33 | fastapi==0.115.12
52 | idna==3.10
53 | Jinja2==3.1.6
54 | keras==3.9.2
55 | + langcodes==3.5.0
56 | + language_data==1.3.0
57 | libclang==18.1.1
58 | + marisa-trie==1.2.1
59 | Markdown==3.8
60 | markdown-it-py==3.0.0
61 | MarkupSafe==3.0.2
62 | mdurl==0.1.2
63 | ml_dtypes==0.5.1
64 | + mpmath==1.3.0
65 | multidict==6.3.2
66 | multiprocess==0.70.16
67 | + murmurhash==1.0.13
68 | namex==0.0.8
69 | + networkx==3.4.2
70 | numpy==2.1.3
71 | + nvidia-cublas-cu12==12.6.4.1
72 | + nvidia-cuda-cupti-cu12==12.6.80
73 | + nvidia-cuda-nvrtc-cu12==12.6.77
74 | + nvidia-cuda-runtime-cu12==12.6.77
75 | + nvidia-cudnn-cu12==9.5.1.17
76 | + nvidia-cufft-cu12==11.3.0.4
77 | + nvidia-cufile-cu12==1.11.1.6
78 | + nvidia-curand-cu12==10.3.7.77
79 | + nvidia-cusolver-cu12==11.7.1.2
80 | + nvidia-cusparse-cu12==12.5.4.2
81 | + nvidia-cusparselt-cu12==0.6.3
82 | + nvidia-nccl-cu12==2.26.2
83 | + nvidia-nvjitlink-cu12==12.6.85
84 | + nvidia-nvtx-cu12==12.6.77
85 | opencv-python==4.11.0.86
86 | openpyxl==3.1.5
87 | opt_einsum==3.4.0
90 | packaging==24.2
91 | pandas==2.2.3
92 | pillow==11.1.0
93 | + pillow_heif==1.0.0
94 | + preshed==3.0.10
95 | propcache==0.3.1
96 | protobuf==5.29.4
97 | + psutil==7.0.0
98 | pyarrow==19.0.1
99 | + pydantic
100 | + pydantic_core
101 | pydub==0.25.1
102 | Pygments==2.19.1
103 | PySocks==1.7.1
104 | + pythainlp==5.1.2
105 | python-dateutil==2.9.0.post0
106 | python-dotenv==1.1.0
107 | python-multipart==0.0.20
108 | pytz==2025.2
109 | + pyuca==1.2
110 | PyYAML==6.0.2
111 | + qwen-vl-utils==0.0.8
112 | + regex==2024.11.6
113 | requests==2.32.3
114 | retina-face==0.0.17
115 | rich==14.0.0
116 | ruff==0.11.4
117 | safehttpx==0.1.6
118 | + safetensors==0.5.3
119 | semantic-version==2.10.0
120 | shellingham==1.5.4
121 | six==1.17.0
122 | + smart_open==7.3.0.post1
123 | sniffio==1.3.1
124 | soupsieve==2.6
125 | + spacy==3.8.7
126 | + spacy-legacy==3.0.12
127 | + spacy-loggers==1.0.5
128 | + spacy-thai==0.7.8
129 | + spacy-udpipe==1.0.0
130 | + srsly==2.5.1
131 | starlette==0.46.1
132 | + sympy==1.14.0
133 | tensorboard==2.19.0
134 | tensorboard-data-server==0.7.2
135 | tensorflow==2.19.0
136 | tensorflow-io-gcs-filesystem==0.37.1
137 | termcolor==3.0.1
138 | tf_keras==2.19.0
139 | + thinc==8.3.6
140 | + timm==1.0.19
141 | + tokenizers==0.21.2
142 | tomlkit==0.13.2
143 | + torch==2.7.1
144 | + https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
145 | + torchao==0.13.0
146 | + torchvision==0.22.1
147 | tqdm==4.67.1
148 | + transformers==4.53.3
149 | + triton==3.3.1
150 | typer==0.15.2
151 | + typing-inspection
152 | + typing_extensions
153 | tzdata==2025.2
154 | + ufal.udpipe==1.3.1.1
155 | urllib3==2.3.0
156 | uvicorn==0.34.0
157 | + wasabi==1.1.3
158 | + weasel==0.4.1
159 | websockets==15.0.1
160 | Werkzeug==3.1.3
161 | wrapt==1.17.2
162 | xxhash==3.5.0
163 | yarl==1.19.0
164 | + supabase==2.18.1
165 | + supabase_auth==2.12.3
166 | + supabase_functions==0.10.1
167 | + # flash_attn==2.8.1
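The flash-attention entry above is a prebuilt wheel pinned to CPython 3.10, torch 2.7 and CUDA 12 (the cp310, torch2.7 and cu12 tags in its file name), so it installs cleanly only when the Space's runtime matches those versions; the commented-out flash_attn==2.8.1 line is the source-build fallback. A quick runtime check (a sketch for verification, not part of the commit):

# Sketch: confirm the environment matches the prebuilt flash_attn wheel tags.
import platform
import torch

print("python:", platform.python_version())  # cp310 wheel expects Python 3.10.x
print("torch:", torch.__version__)           # wheel is built against torch 2.7
print("cuda:", torch.version.cuda)           # wheel is built against a CUDA 12.x toolkit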
ui/layout.py
CHANGED
@@ -1,5 +1,6 @@
1 | import gradio as gr
2 | import time
3 |
4 | from logic.data_utils import CustomHFDatasetSaver
5 | from data.lang2eng_map import lang2eng_mapping
@@ -12,6 +13,161 @@ from .selection_page import build_selection_page
12 | from .main_page import build_main_page
13 | from .main_page import sort_with_pyuca
14 |
15 | def get_key_by_value(dictionary, value):
16 |     for key, val in dictionary.items():
17 |         if val == value:
@@ -100,14 +256,27 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
100 |         object-fit: contain; /* make sure the full image shows */
101 |         height: 460px; /* set a fixed height */
102 |     }
103 |     """
104 |     ############################################################################
105 |     with gr.Blocks(css=custom_css) as ui:
106 |         local_storage = gr.State([None, None, "", ""])
107 |         loading_example = gr.State(False) # to check if the values are loaded from a user click on an example in
108 |         # First page: selection
109 |
110 | -         selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown = build_selection_page(metadata_dict)
111 |
112 |         # Second page
113 |         cmp_main_ui = build_main_page(concepts_dict, metadata_dict, local_storage)
@@ -144,8 +313,20 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
144 |         modal_exclude_confirm = cmp_main_ui["modal_exclude_confirm"]
145 |         cancel_exclude_btn = cmp_main_ui["cancel_exclude_btn"]
146 |         confirm_exclude_btn = cmp_main_ui["confirm_exclude_btn"]
147 | -
148 | -
149 |         ### Category button
150 |         category_btn.change(
151 |             fn=partial(load_concepts, concepts=concepts_dict),
@@ -214,7 +395,7 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
214 |         clear_btn.click(
215 |             fn=clear_data,
216 |             outputs=[
217 | -                 image_inp, image_url_inp, long_caption_inp, exampleid_btn,
218 |                 category_btn, concept_btn,
219 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
220 |                 category_concept_dropdowns[3], category_concept_dropdowns[4]
@@ -280,12 +461,12 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
280 |         # Handle clicking on an example
281 |         user_examples.click(
282 |             fn=partial(handle_click_example, concepts_dict=concepts_dict),
283 | -             inputs=[user_examples],
284 |             outputs=[
285 |                 image_inp, image_url_inp, long_caption_inp, exampleid_btn,
286 |                 category_btn, concept_btn,
287 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
288 | -                 category_concept_dropdowns[3], category_concept_dropdowns[4], loading_example
289 |             ],
290 |         )
291 |
@@ -295,6 +476,41 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
295 |
296 |         # ============================================ #
297 |         # Submit Button Click events
298 |
299 |         proceed_btn.click(
300 |             fn=partial(switch_ui, flag=False),
@@ -313,8 +529,8 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
313 |             ]
314 |         ).then(
315 |             fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
316 | -             inputs=[
317 | -             outputs=[user_examples, loading_msg],
318 |         )
319 |
320 |
@@ -322,7 +538,7 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
322 |         exit_btn.click(
323 |             fn=exit,
324 |             outputs=[
325 | -                 image_inp, image_url_inp, long_caption_inp, user_examples, loading_msg,
326 |                 username, password, local_storage, exampleid_btn, category_btn, concept_btn,
327 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
328 |                 category_concept_dropdowns[3], category_concept_dropdowns[4]
@@ -368,7 +584,10 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
368 |             "excluded": gr.State(value=False),
369 |             "concepts_dict": gr.State(value=concepts_dict),
370 |             "country_lang_map": gr.State(value=lang2eng_mapping),
371 |             # "is_blurred": is_blurred
372 |         }
373 |         # data_outputs = [image_inp, image_url_inp, long_caption_inp,
374 |         #                 country_inp, language_inp, category_btn, concept_btn,
@@ -376,34 +595,56 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
376 |         hf_writer.setup(list(data_outputs.keys()), local_ds_folder = LOCAL_DS_DIRECTORY_PATH)
377 |
378 |         # STEP 4: Chain save_data, then update_user_data, then re-enable button, hide modal, and clear
379 | -         submit_btn.click(
380-406 | -
@@ -446,13 +687,13 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
446 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
447 |                 category_concept_dropdowns[3], category_concept_dropdowns[4],
448 |                 timestamp_btn, username_inp, password_inp, exampleid_btn, gr.State(value=True),
449 | -                 gr.State(value=concepts_dict), gr.State(value=lang2eng_mapping)
450 |             ],
451 |             outputs=None
452 |         ).success(
453 |             fn=partial(clear_data, "remove"),
454 | -             outputs=[
455 | -                 image_inp, image_url_inp, long_caption_inp, exampleid_btn,
456 |                 category_btn, concept_btn,
457 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
458 |                 category_concept_dropdowns[3], category_concept_dropdowns[4]
@@ -465,8 +706,32 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
465 |             outputs=loading_msg
466 |         ).success(
467 |             fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path=LOCAL_DS_DIRECTORY_PATH),
468 | -             inputs=[
469 | -             outputs=[user_examples, loading_msg]
470 |         )
471 |
472 | -     return ui

1 | import gradio as gr
2 | import time
3 | + from logic.supabase_client import auth_handler
4 |
5 | from logic.data_utils import CustomHFDatasetSaver
6 | from data.lang2eng_map import lang2eng_mapping
13 | from .main_page import build_main_page
14 | from .main_page import sort_with_pyuca
15 |
16 | + js_code = """
17 | + function() {
18 | +     // Get the full URL with the fragment
19 | +     const url = window.location.href;
20 | +     const fragment = url.split('#')[1];
21 | +
22 | +     if (!fragment) {
23 | +         return "";
24 | +     }
25 | +
26 | +     // Parse the fragment into an object
27 | +     const params = new URLSearchParams(fragment);
28 | +     const access_token = params.get('access_token');
29 | +     const refresh_token = params.get('refresh_token');
30 | +
31 | +     // Create a JSON string with the tokens
32 | +     const tokens = JSON.stringify({
33 | +         access_token: access_token,
34 | +         refresh_token: refresh_token
35 | +     });
36 | +
37 | +     // Return the JSON string to the Gradio output component
38 | +     return tokens;
39 | + }
40 | + """
41 | +
42 | + def login_user(email, password):
43 | +     result = auth_handler.login(email, password)
44 | +     if result['success']:
45 | +         session_data = result['data']
46 | +         persistent_data = {
47 | +             "refresh_token": session_data['refresh_token'],
48 | +             "user_email": session_data['user_email']
49 | +         }
50 | +         return session_data['client'], persistent_data, result['message']
51 | +     else:
52 | +         persistent_data = {
53 | +             "refresh_token": "",
54 | +             "user_email": ""
55 | +         }
56 | +         return None, persistent_data, result['message']
57 | +
58 | + def login_user_recovery(session_data: str):
59 | +     """
60 | +     This function receives session data (tokens as a JSON string) from the frontend,
61 | +     retrieves the session, and returns data in a format similar to login_user.
62 | +     """
63 | +     try:
64 | +         import json
65 | +         tokens = json.loads(session_data)
66 | +         access_token = tokens.get("access_token")
67 | +         refresh_token = tokens.get("refresh_token")
68 | +
69 | +         if not access_token or not refresh_token:
70 | +             return None, gr.skip(), "Invalid session data provided."
71 | +
72 | +         result = auth_handler.retrieve_session_from_tokens(access_token, refresh_token)
73 | +
74 | +         if result['success']:
75 | +             session_data_result = result['data']
76 | +             persistent_data = {
77 | +                 "refresh_token": session_data_result['refresh_token'],
78 | +                 "user_email": session_data_result['user_email']
79 | +             }
80 | +             return session_data_result['client'], persistent_data, result['message']
81 | +         else:
82 | +             persistent_data = {
83 | +                 "refresh_token": "",
84 | +                 "user_email": ""
85 | +             }
86 | +             return None, persistent_data, result['message']
87 | +
88 | +     except Exception as e:
89 | +         return None, gr.skip(), f"Failed to process recovery login: {e}"
90 | +
91 | + def sign_up(email, password):
92 | +     result = auth_handler.sign_up(email, password)
93 | +     return result['message']
94 | +
95 | + def reset_password(email):
96 | +     result = auth_handler.reset_password_for_email(email)
97 | +     return result['message']
98 | +
99 | + def log_out(supabase_user_client, persistent_session):
100 | +     """
101 | +     Logs out the user and clears the session. If error occurs, it returns an empty persistent session (logging out user).
102 | +     """
103 | +     persistent_session = {
104 | +         "refresh_token": "",
105 | +         "user_email": ""
106 | +     }
107 | +     if supabase_user_client:
108 | +         result = auth_handler.logout(supabase_user_client)
109 | +         if result['success']:
110 | +             print("User logged out successfully.")
111 | +             return persistent_session
112 | +         else:
113 | +             print(f"Error logging out: {result['message']}")
114 | +             return persistent_session
115 | +     else:
116 | +         print("No user client provided to log out.")
117 | +         return persistent_session
118 | +
119 | + def restore_user_session(session_data, login_status=None):
120 | +     print("Restoring user session with data:", session_data)
121 | +     # default values if the user is not logged in
122 | +     # or the session data is not valid
123 | +     login_status_update = gr.update(value=login_status if login_status else "")
124 | +     proceed_button_update = gr.update(value="Proceed as Anonymous User", interactive=True)
125 | +     login_button_update = gr.update(visible=True)
126 | +     sign_up_button_update = gr.update(visible=True)
127 | +     reset_password_button_update = gr.update(visible=True)
128 | +     logout_button_update = gr.update(visible=False)
129 | +     change_password_field_update = gr.update(visible=False)
130 | +     change_password_field_confirm_update = gr.update(visible=False)
131 | +     change_password_button_update = gr.update(visible=False)
132 | +     change_password_status_update = gr.update(value="")
133 | +     persistent_data = {
134 | +         "refresh_token": "",
135 | +         "user_email": ""
136 | +     }
137 | +     if not session_data or not session_data.get('refresh_token', ''):
138 | +         print("No session data found, proceeding as anonymous user.")
139 | +         return None, persistent_data, login_status_update, proceed_button_update, login_button_update, sign_up_button_update, reset_password_button_update, logout_button_update, change_password_field_update, change_password_field_confirm_update, change_password_button_update, change_password_status_update
140 | +
141 | +     result = auth_handler.restore_session(session_data['refresh_token'])
142 | +     if result['success']:
143 | +         restored_session = result['data']
144 | +         new_persistent_data = {
145 | +             "refresh_token": restored_session['refresh_token'],
146 | +             "user_email": restored_session['user_email']
147 | +         }
148 | +         login_status_update = gr.update(value=result['message'])
149 | +         proceed_button_update = gr.update(value="Proceed", interactive=True)
150 | +         login_button_update = gr.update(visible=False)
151 | +         sign_up_button_update = gr.update(visible=False)
152 | +         reset_password_button_update = gr.update(visible=False)
153 | +         logout_button_update = gr.update(visible=True)
154 | +         change_password_field_update = gr.update(visible=True)
155 | +         change_password_field_confirm_update = gr.update(visible=True)
156 | +         change_password_button_update = gr.update(visible=True)
157 | +         return restored_session['client'], new_persistent_data, login_status_update, proceed_button_update, login_button_update, sign_up_button_update, reset_password_button_update, logout_button_update, change_password_field_update, change_password_field_confirm_update, change_password_button_update, change_password_status_update
158 | +     else:
159 | +         return None, persistent_data, login_status_update, proceed_button_update, login_button_update, sign_up_button_update, reset_password_button_update, logout_button_update, change_password_field_update, change_password_field_confirm_update, change_password_button_update, change_password_status_update
160 | +
161 | + def change_password(supabase_user_client, new_password, confirm_password):
162 | +     """
163 | +     Changes the user's password.
164 | +     """
165 | +     if new_password != confirm_password:
166 | +         return "Passwords do not match. Please try again."
167 | +     result = auth_handler.change_password(supabase_user_client, new_password)
168 | +     return result['message']
169 | +
170 | +
171 | def get_key_by_value(dictionary, value):
172 |     for key, val in dictionary.items():
173 |         if val == value:
256 |         object-fit: contain; /* make sure the full image shows */
257 |         height: 460px; /* set a fixed height */
258 |     }
259 | +     #vlm_output .input-container {
260 | +         position: relative;
261 | +     }
262 | +     #vlm_output .input-container::before {
263 | +         content: "";
264 | +         position: absolute;
265 | +         top: 0; left: 0; right: 0; bottom: 0;
266 | +         z-index: 10; /* sits above the textarea */
267 | +         background: transparent;
268 | +     }
269 |     """
270 |     ############################################################################
271 |     with gr.Blocks(css=custom_css) as ui:
272 | +         supabase_user_client = gr.State(None)
273 | +         persistent_session = gr.BrowserState(None)
274 | +
275 |         local_storage = gr.State([None, None, "", ""])
276 |         loading_example = gr.State(False) # to check if the values are loaded from a user click on an example in
277 |         # First page: selection
278 |
279 | +         selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown, login_btn, sign_up_btn, reset_password_btn, login_status, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status = build_selection_page(metadata_dict)
280 |
281 |         # Second page
282 |         cmp_main_ui = build_main_page(concepts_dict, metadata_dict, local_storage)
313 |         modal_exclude_confirm = cmp_main_ui["modal_exclude_confirm"]
314 |         cancel_exclude_btn = cmp_main_ui["cancel_exclude_btn"]
315 |         confirm_exclude_btn = cmp_main_ui["confirm_exclude_btn"]
316 | +         vlm_output = cmp_main_ui["vlm_output"]
317 | +         gen_button = cmp_main_ui["gen_button"]
318 | +         vlm_feedback = cmp_main_ui["vlm_feedback"]
319 | +         modal_vlm = cmp_main_ui["modal_vlm"]
320 | +         vlm_no_btn = cmp_main_ui["vlm_no_btn"]
321 | +         vlm_done_btn = cmp_main_ui["vlm_done_btn"]
322 | +         submit_yes = cmp_main_ui["submit_yes"]
323 | +         submit_no = cmp_main_ui["submit_no"]
324 | +         modal_submit = cmp_main_ui["modal_submit"]
325 | +         vlm_cancel_btn = cmp_main_ui["vlm_cancel_btn"]
326 | +         vlm_model_dropdown = cmp_main_ui["vlm_model_dropdown"]
327 | +
328 | +         # dictionary to store all vlm_output by exampleid
329 | +         vlm_captions = gr.State(None)
330 |         ### Category button
331 |         category_btn.change(
332 |             fn=partial(load_concepts, concepts=concepts_dict),
395 |         clear_btn.click(
396 |             fn=clear_data,
397 |             outputs=[
398 | +                 image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
399 |                 category_btn, concept_btn,
400 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
401 |                 category_concept_dropdowns[3], category_concept_dropdowns[4]
461 |         # Handle clicking on an example
462 |         user_examples.click(
463 |             fn=partial(handle_click_example, concepts_dict=concepts_dict),
464 | +             inputs=[user_examples, vlm_captions],
465 |             outputs=[
466 |                 image_inp, image_url_inp, long_caption_inp, exampleid_btn,
467 |                 category_btn, concept_btn,
468 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
469 | +                 category_concept_dropdowns[3], category_concept_dropdowns[4], loading_example, vlm_output
470 |             ],
471 |         )
472 |
476 |
477 |         # ============================================ #
478 |         # Submit Button Click events
479 | +         login_btn.click(
480 | +             fn=login_user,
481 | +             inputs=[username, password],
482 | +             outputs=[supabase_user_client, persistent_session, login_status],
483 | +         ).then(
484 | +             fn=restore_user_session,
485 | +             inputs=[persistent_session, login_status],
486 | +             outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
487 | +         )
488 | +
489 | +         sign_up_btn.click(
490 | +             fn=sign_up,
491 | +             inputs=[username, password],
492 | +             outputs=[login_status],
493 | +         )
494 | +
495 | +         logout_btn.click(
496 | +             fn=log_out,
497 | +             inputs=[supabase_user_client, persistent_session],
498 | +             outputs=[persistent_session]
499 | +         ).then(
500 | +             fn=restore_user_session,
501 | +             inputs=[persistent_session],
502 | +             outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
503 | +         )
504 | +         change_password_btn.click(
505 | +             fn=change_password,
506 | +             inputs=[supabase_user_client, change_password_field, change_password_field_confirm],
507 | +             outputs=[change_password_status]
508 | +         )
509 | +         reset_password_btn.click(
510 | +             fn=reset_password,
511 | +             inputs=[username],
512 | +             outputs=[login_status]
513 | +         )
514 |
515 |         proceed_btn.click(
516 |             fn=partial(switch_ui, flag=False),
529 |             ]
530 |         ).then(
531 |             fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
532 | +             inputs=[supabase_user_client, country_choice, language_choice],
533 | +             outputs=[user_examples, loading_msg, vlm_captions],
534 |         )
535 |
536 |
538 |         exit_btn.click(
539 |             fn=exit,
540 |             outputs=[
541 | +                 image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, user_examples, loading_msg,
542 |                 username, password, local_storage, exampleid_btn, category_btn, concept_btn,
543 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
544 |                 category_concept_dropdowns[3], category_concept_dropdowns[4]
584 |             "excluded": gr.State(value=False),
585 |             "concepts_dict": gr.State(value=concepts_dict),
586 |             "country_lang_map": gr.State(value=lang2eng_mapping),
587 | +             "client": supabase_user_client,
588 |             # "is_blurred": is_blurred
589 | +             "vlm_caption": vlm_output,
590 | +             "vlm_feedback": vlm_feedback
591 |         }
592 |         # data_outputs = [image_inp, image_url_inp, long_caption_inp,
593 |         #                 country_inp, language_inp, category_btn, concept_btn,
595 |         hf_writer.setup(list(data_outputs.keys()), local_ds_folder = LOCAL_DS_DIRECTORY_PATH)
596 |
597 |         # STEP 4: Chain save_data, then update_user_data, then re-enable button, hide modal, and clear
598 | +         # submit_btn.click(lambda: Modal(visible=True), None, modal_vlm)
599 | +         submit_btn.click(submit_button_clicked,
600 | +             inputs=[vlm_output],
601 | +             outputs=[modal_vlm, modal_submit])
602 | +
603 | +         # submit_btn.click(partial(submit_button_clicked, save_fn=hf_writer.save,
604 | +         #                          data_outputs=data_outputs),
605 | +         #                  inputs=[vlm_output],
606 | +         #                  outputs=[modal_vlm, image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, exampleid_btn,
607 | +         #                           category_btn, concept_btn, category_concept_dropdowns[0], category_concept_dropdowns[1],
608 | +         #                           category_concept_dropdowns[2], category_concept_dropdowns[3], category_concept_dropdowns[4]])
609 | +
610 | +         def wire_submit_chain(button, modal_ui):
611 | +             e = button.click(
612 | +                 fn=lambda: Modal(visible=False),
613 | +                 outputs=[modal_ui]
614 | +             ).success(
615 | +                 hf_writer.save,
616 | +                 inputs = list(data_outputs.values()),
617 | +                 outputs = None,
618 | +             ).success(
619 | +                 fn=partial(clear_data, "submit"),
620 | +                 outputs=[
621 | +                     image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
622 | +                     category_btn, concept_btn,
623 | +                     category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
624 | +                     category_concept_dropdowns[3], category_concept_dropdowns[4]
625 | +                 ],
626 | +             # ).success(enable_submit,
627 | +             #     None, [submit_btn]
628 | +             # ).success(lambda: Modal(visible=False),
629 | +             #     None, modal_saving
630 | +             # ).success(lambda: Modal(visible=True),
631 | +             #     None, modal_data_saved
632 | +             ).success(
633 | +                 # set loading msg
634 | +                 lambda: gr.update(value="**Loading your data, please wait ...**"),
635 | +                 None, loading_msg
636 | +             ).success(
637 | +                 fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
638 | +                 inputs=[supabase_user_client, country_choice, language_choice],
639 | +                 outputs=[user_examples, loading_msg, vlm_captions]
640 | +             )
641 | +             return e
642 | +
643 | +         wire_submit_chain(vlm_done_btn, modal_vlm)
644 | +         wire_submit_chain(vlm_no_btn, modal_vlm)
645 | +         wire_submit_chain(submit_yes, modal_submit)
646 | +         submit_no.click(lambda: Modal(visible=False), None, modal_submit)
647 | +         vlm_cancel_btn.click(lambda: Modal(visible=False), None, modal_vlm)
648 |         # ============================================ #
649 |         # instructions button
650 |         instruct_btn.click(lambda: Modal(visible=True), None, modal)
687 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
688 |                 category_concept_dropdowns[3], category_concept_dropdowns[4],
689 |                 timestamp_btn, username_inp, password_inp, exampleid_btn, gr.State(value=True),
690 | +                 gr.State(value=concepts_dict), gr.State(value=lang2eng_mapping), vlm_output, vlm_feedback
691 |             ],
692 |             outputs=None
693 |         ).success(
694 |             fn=partial(clear_data, "remove"),
695 | +             outputs=[
696 | +                 image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
697 |                 category_btn, concept_btn,
698 |                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
699 |                 category_concept_dropdowns[3], category_concept_dropdowns[4]
706 |             outputs=loading_msg
707 |         ).success(
708 |             fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path=LOCAL_DS_DIRECTORY_PATH),
709 | +             inputs=[supabase_user_client, country_choice, language_choice],
710 | +             outputs=[user_examples, loading_msg, vlm_captions]
711 | +         )
712 | +         # ============================================= #
713 | +         # VLM Gen button
714 | +         # ============================================= #
715 | +         gen_button.click(
716 | +             fn=generate_vlm_caption,  # processor=processor, model=model
717 | +             inputs=[image_inp, vlm_model_dropdown],
718 | +             outputs=[vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button]
719 | +         )
720 | +         # vlm_output.change(
721 | +         #     fn=lambda : gr.update(interactive=False) if vlm_output.value else gr.update(interactive=True),
722 | +         #     inputs=[],
723 | +         #     outputs=[gen_button]
724 | +         # )
725 | +
726 | +         ui.load(
727 | +             fn=login_user_recovery,
728 | +             inputs=gr.Textbox(visible=False, value=""),  # hidden textbox to get the url tokens
729 | +             outputs=[supabase_user_client, persistent_session, login_status],
730 | +             js=js_code
731 | +         ).then(
732 | +             fn=restore_user_session,
733 | +             inputs=[persistent_session],
734 | +             outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
735 |         )
736 | +     return ui
737 |
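ui/layout.py delegates all authentication to the auth_handler imported from the new logic/supabase_client.py, which is added by this commit but not shown in this section. A minimal sketch of the interface those handlers assume, using the supabase client added to requirements.txt, could look like the following; the class name, method bodies, and result-dict fields are illustrative assumptions, not the actual file contents (only a few of the referenced methods are sketched):

# Sketch only: illustrates the {"success", "data", "message"} contract that
# login_user / restore_user_session / log_out in ui/layout.py rely on.
from supabase import create_client
from config.settings import SUPABASE_URL, SUPABASE_KEY

class AuthHandler:
    def _ok(self, client, session, message):
        # Shape matches what the layout handlers read from result["data"].
        return {
            "success": True,
            "data": {
                "client": client,
                "refresh_token": session.refresh_token,
                "user_email": session.user.email if session.user else "",
            },
            "message": message,
        }

    def _err(self, message):
        return {"success": False, "data": None, "message": message}

    def login(self, email, password):
        client = create_client(SUPABASE_URL, SUPABASE_KEY)
        try:
            res = client.auth.sign_in_with_password({"email": email, "password": password})
            return self._ok(client, res.session, "Logged in successfully.")
        except Exception as e:
            return self._err(f"Login failed: {e}")

    def restore_session(self, refresh_token):
        client = create_client(SUPABASE_URL, SUPABASE_KEY)
        try:
            res = client.auth.refresh_session(refresh_token)
            return self._ok(client, res.session, "Welcome back! Session restored.")
        except Exception as e:
            return self._err(f"Could not restore session: {e}")

    def retrieve_session_from_tokens(self, access_token, refresh_token):
        # Used by login_user_recovery after the password-recovery redirect.
        client = create_client(SUPABASE_URL, SUPABASE_KEY)
        try:
            res = client.auth.set_session(access_token, refresh_token)
            return self._ok(client, res.session, "Recovery session established.")
        except Exception as e:
            return self._err(f"Could not establish session: {e}")

    def logout(self, client):
        try:
            client.auth.sign_out()
            return {"success": True, "data": None, "message": "Logged out."}
        except Exception as e:
            return self._err(str(e))

auth_handler = AuthHandler()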
ui/main_page.py
CHANGED
@@ -107,7 +107,36 @@ def build_main_page(concepts_dict, metadata_dict, local_storage):
107 |     long_caption_inp = gr.Textbox(lines=6, label="Description", elem_id="long_caption_inp")
108 |     num_words_inp = gr.Textbox(lines=1, label="Number of words", elem_id="num_words", interactive=False, value=0)
109 |     # num_words_inp = gr.Markdown("Number of words", elem_id="num_words")
110 |
111 |     categories_list = sort_with_pyuca(list(concepts_dict["USA"]["English"].keys()))
112 |
113 |     def create_category_dropdown(category, index):
@@ -226,5 +255,16 @@ def build_main_page(concepts_dict, metadata_dict, local_storage):
226 |         "modal_exclude_confirm": modal_exclude_confirm,
227 |         "cancel_exclude_btn": cancel_exclude_btn,
228 |         "confirm_exclude_btn": confirm_exclude_btn,
229 |     }
230 |     return output_dict

107 |     long_caption_inp = gr.Textbox(lines=6, label="Description", elem_id="long_caption_inp")
108 |     num_words_inp = gr.Textbox(lines=1, label="Number of words", elem_id="num_words", interactive=False, value=0)
109 |     # num_words_inp = gr.Markdown("Number of words", elem_id="num_words")
110 | +     #########################################################
111 | +     with Modal(visible=False, allow_user_close=False) as modal_vlm:
112 |
113 | +         question = gr.Markdown("Would you like to see if a VLM can generate a culturally aware description for your uploaded concept?")
114 | +         with gr.Row():
115 | +             gen_button = gr.Button("Yes", variant="primary", elem_id="generate_answer_btn")
116 | +             vlm_no_btn = gr.Button("No")
117 | +             vlm_cancel_btn = gr.Button("Cancel")
118 | +         vlm_model_dropdown = gr.Dropdown(
119 | +             ["SmolVLM-500M", "Qwen2.5-VL-7B", "InternVL3_5-8B", "Gemma3-4B"], value="Gemma3-4B", multiselect=False, label="VLM Model", info="Select the VLM model to use for generating the description."
120 | +         )
121 | +         vlm_output = gr.Textbox(lines=6, label="Generated description", elem_id="vlm_output", interactive=False)
122 | +         vlm_feedback = gr.Radio(["Yes 👍", "No 👎"], label="Do you think the generated description is accurate within the cultural context of your country?", visible=False, elem_id="vlm_feedback", interactive=True)
123 | +         vlm_done_btn = gr.Button("Complete Submission", visible=False)
124 | +
125 | +     with Modal(visible=False, allow_user_close=False) as modal_submit:
126 | +
127 | +         gr.Markdown("⚠️ You've already generated a caption for this image. An optional description with the VLM can only be generated once. Would you like to proceed and submit your modified data?")
128 | +         with gr.Row():
129 | +             submit_yes = gr.Button("Yes", variant="primary", elem_id="submit_confirm_yes")
130 | +             submit_no = gr.Button("No", variant="stop", elem_id="submit_confirm_no")
131 | +
132 | +     # with gr.Group():
133 | +     #     gr.Markdown("### VLM Generation (Optional)")
134 | +     #     with gr.Accordion("Click here if you want to get a generated answer from a small vlm", open=False):
135 | +     #         gen_button = gr.Button("Generate Answer", variant="primary", elem_id="generate_answer_btn")
136 | +     #         vlm_output = gr.Textbox(lines=6, label="Generated Answer", elem_id="vlm_output", interactive=False)
137 | +     #         vlm_feedback = gr.Radio(["Yes 👍", "No 👎"], label="Do you like the generated caption?", visible=False, elem_id="vlm_feedback", interactive=True)
138 | +     ##########################################################
139 | +
140 |     categories_list = sort_with_pyuca(list(concepts_dict["USA"]["English"].keys()))
141 |
142 |     def create_category_dropdown(category, index):
255 |         "modal_exclude_confirm": modal_exclude_confirm,
256 |         "cancel_exclude_btn": cancel_exclude_btn,
257 |         "confirm_exclude_btn": confirm_exclude_btn,
258 | +         "vlm_output": vlm_output,
259 | +         "gen_button": gen_button,
260 | +         "vlm_feedback": vlm_feedback,
261 | +         "modal_vlm": modal_vlm,
262 | +         "vlm_no_btn": vlm_no_btn,
263 | +         "vlm_done_btn": vlm_done_btn,
264 | +         "submit_yes": submit_yes,
265 | +         "submit_no": submit_no,
266 | +         "modal_submit": modal_submit,
267 | +         "vlm_cancel_btn": vlm_cancel_btn,
268 | +         "vlm_model_dropdown": vlm_model_dropdown
269 |     }
270 |     return output_dict
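The gen_button defined in this modal is wired in ui/layout.py to generate_vlm_caption(image_inp, vlm_model_dropdown) from the new logic/vlm.py, which is also added by this commit but not shown in this section; the wiring expects five outputs (the caption textbox, the feedback radio, the done and "No" buttons, and the generate button itself). A hedged sketch of such a function, using the transformers image-text-to-text pipeline as a stand-in for the real model code (the model mapping, prompt, and pipeline usage are assumptions for illustration):

# Sketch only: not the actual logic/vlm.py implementation.
import gradio as gr
from transformers import pipeline

# Hypothetical mapping from the dropdown labels to Hub checkpoints.
MODEL_IDS = {"Gemma3-4B": "google/gemma-3-4b-it"}

def generate_vlm_caption(image, model_name):
    if image is None:
        # Nothing to caption: leave the other widgets unchanged.
        return (gr.update(value="Please upload an image first."),
                gr.update(), gr.update(), gr.update(), gr.update())

    # In practice the pipeline would be loaded once and cached, not per click.
    pipe = pipeline("image-text-to-text", model=MODEL_IDS.get(model_name, MODEL_IDS["Gemma3-4B"]))
    messages = [{"role": "user", "content": [
        {"type": "image", "image": image},  # assumes a PIL image from gr.Image
        {"type": "text", "text": "Describe this image within its cultural context."},
    ]}]
    out = pipe(text=messages, max_new_tokens=128, return_full_text=False)
    caption = out[0]["generated_text"]

    # Show the caption, reveal the feedback radio and "Complete Submission" button,
    # and disable further generation for this example.
    return (gr.update(value=caption), gr.update(visible=True),
            gr.update(visible=True), gr.update(), gr.update(interactive=False))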
ui/selection_page.py
CHANGED
@@ -57,6 +57,24 @@ def build_selection_page(metadata_dict):
57 |     username = gr.Textbox(label="Email (optional)", type="email", elem_id="username_text")
58 |     password = gr.Textbox(label="Password (optional)", type="password", elem_id="password_text")
59 |
60 | -
61 |
62 | -

57 |     username = gr.Textbox(label="Email (optional)", type="email", elem_id="username_text")
58 |     password = gr.Textbox(label="Password (optional)", type="password", elem_id="password_text")
59 |
60 | +     with gr.Row():
61 | +         login_btn = gr.Button("Login", elem_id="login_btn")
62 | +         sign_up_btn = gr.Button("Sign up", elem_id="sign_up_btn")
63 | +         reset_password_btn = gr.Button("Reset Password", elem_id="reset_password_btn")
64 | +         logout_btn = gr.Button("Logout", elem_id="logout_btn", visible=False)
65 |
66 | +     login_status = gr.Markdown("")
67 | +     with gr.Row():
68 | +         proceed_btn = gr.Button("Proceed")
69 | +     with gr.Row():
70 | +         change_password_field = gr.Textbox(
71 | +             label="Change Password", type="password", elem_id="change_password_field", visible=True
72 | +         )
73 | +         change_password_field_confirm = gr.Textbox(
74 | +             label="Confirm New Password", type="password", elem_id="change_password_field_confirm", visible=True
75 | +         )
76 | +     with gr.Row():
77 | +         change_password_btn = gr.Button("Change Password", elem_id="change_password_btn", visible=True)
78 | +         change_password_status = gr.Markdown("")
79 | +
80 | +     return selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown, login_btn, sign_up_btn, reset_password_btn, login_status, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status
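The Reset Password button added here relies on Supabase e-mailing a recovery link that sends the user back to the Space with access_token and refresh_token in the URL fragment, which js_code and login_user_recovery in ui/layout.py then consume on ui.load. A sketch of how the reset call might be issued with the REDIRECT_TO_URL from config/settings.py (the real implementation lives in logic/supabase_client.py, not shown in this section, and the method body below is an assumption):

# Sketch only: illustrative reset-password call for the recovery flow.
from supabase import create_client
from config.settings import SUPABASE_URL, SUPABASE_KEY, REDIRECT_TO_URL

def reset_password_for_email(email: str) -> dict:
    client = create_client(SUPABASE_URL, SUPABASE_KEY)
    try:
        # Supabase sends a recovery e-mail; following the link reopens the app at
        # REDIRECT_TO_URL with the tokens in the URL fragment, which js_code in
        # ui/layout.py forwards to login_user_recovery.
        client.auth.reset_password_for_email(email, {"redirect_to": REDIRECT_TO_URL})
        return {"success": True, "message": "Password reset e-mail sent."}
    except Exception as e:
        return {"success": False, "message": f"Could not send reset e-mail: {e}"}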