carlosh93 committed
Commit ed8368e · 1 Parent(s): 0bd89b7

Updating to a new version with Supabase authentication and VLM captioning

app.py CHANGED
@@ -1,6 +1,8 @@
  # import spacy.cli
  # spacy.cli.download("ja_core_news_sm")
  # spacy.cli.download("zh_core_web_sm")
+ import os
+ os.environ["TF_USE_LEGACY_KERAS"] = "1"
  import spacy_udpipe
  spacy_udpipe.download("ja")
  spacy_udpipe.download("zh")
@@ -16,7 +18,7 @@ metadata = load_metadata()
 
  demo = build_ui(concepts, metadata, HF_API_TOKEN, HF_DATASET_NAME)
  # demo.launch()
- demo.launch(debug=False)
+ demo.launch(debug=False, server_port=7861)
 
  demo.close()
  # gr.close_all()
config/settings.py CHANGED
@@ -3,6 +3,9 @@ import os
 
  load_dotenv()
 
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+ HF_API_TOKEN = os.getenv("HF_TOKEN")
  HF_DATASET_NAME = os.getenv("HF_DATASET_NAME")
- LOCAL_DS_DIRECTORY_PATH = os.getenv("LOCAL_DS_DIRECTORY_PATH")
+ LOCAL_DS_DIRECTORY_PATH = os.getenv("LOCAL_DS_DIRECTORY_PATH")
+ SUPABASE_URL: str = os.getenv("SUPABASE_URL")
+ SUPABASE_KEY: str = os.getenv("SUPABASE_KEY")
+ REDIRECT_TO_URL: str = os.getenv("REDIRECT_TO_URL")
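The settings module now reads the Supabase credentials and the password-reset redirect from the environment. A minimal sketch of how downstream code can consume them; the example entries named in the comments are placeholders, not values taken from the repo:

# Hypothetical .env entries loaded by load_dotenv() above (all values are placeholders):
#   HF_TOKEN=..., HF_DATASET_NAME=..., LOCAL_DS_DIRECTORY_PATH=...
#   SUPABASE_URL=https://<project>.supabase.co, SUPABASE_KEY=<anon key>, REDIRECT_TO_URL=<reset-password page>
from config.settings import SUPABASE_URL, SUPABASE_KEY

if not (SUPABASE_URL and SUPABASE_KEY):
    raise RuntimeError("Supabase credentials are missing from the environment")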
logic/data_utils.py CHANGED
@@ -8,7 +8,7 @@ import uuid
  import gradio as gr
  from PIL import Image
  import numpy as np
-
+ from logic.supabase_client import auth_handler
 
  def load_concepts(path="data/concepts.json"):
      with open(path, encoding='utf-8') as f:
@@ -53,6 +53,9 @@ class CustomHFDatasetSaver:
          self.local_ds_folder = local_ds_folder
          os.makedirs(self.local_ds_folder, exist_ok=True)
 
+         # Migrate any existing JSON files to include new VLM fields
+         self._migrate_existing()
+
          self.data_outputs = data_outputs  # list of components to read values from
 
          # create scheduler to commit the data to the hub every x minutes
@@ -63,6 +66,27 @@ class CustomHFDatasetSaver:
              every=1,
              token=self.api_token,
          )
+
+     def _migrate_existing(self):
+         """
+         Ensure all existing JSON sample files have the same schema
+         by adding missing keys for 'vlm_caption' and 'vlm_feedback'.
+         """
+         for root, _, files in os.walk(self.local_ds_folder):
+             for fname in files:
+                 if fname.endswith('.json'):
+                     fpath = os.path.join(root, fname)
+                     with open(fpath, 'r+', encoding='utf-8') as f:
+                         data = json.load(f)
+                         updated = False
+                         for key in ['vlm_caption', 'vlm_feedback']:
+                             if key not in data:
+                                 data[key] = ""
+                                 updated = True
+                         if updated:
+                             f.seek(0)
+                             json.dump(data, f, indent=2)
+                             f.truncate()
 
      def validate_data(self, values_dic):
          """
@@ -166,10 +190,16 @@ class CustomHFDatasetSaver:
          values_dic["id"] = f'{country}_{language}_{category}_{concept}_{current_timestamp}'
 
          #prepare the main directory of the sample
-         if values_dic.get("username"):
-             sample_dir = os.path.join("logged_in_users", values_dic["country"], values_dic["language"], values_dic["username"], str(current_timestamp))
+         # here we check if the user is logged in or not
+         user_info = auth_handler.is_logged_in(values_dic.get("client", None))
+         print(f"User info: {user_info}")
+         if user_info['success']:
+             # sample_dir = os.path.join("logged_in_users", values_dic["country"], values_dic["language"], values_dic["username"], str(current_timestamp))
+             sample_dir = os.path.join("logged_in_users", values_dic["country"], values_dic["language"], user_info['email'], str(current_timestamp))
+             print(f"Sample directory for logged in user: {sample_dir}")
          else:
              sample_dir = os.path.join("anonymous_users", values_dic["country"], values_dic["language"], str(uuid.uuid4()), str(current_timestamp))
+         print(f"Sample directory: {sample_dir}")
 
          os.makedirs(os.path.join(self.local_ds_folder, sample_dir), exist_ok=True)
 
@@ -217,6 +247,8 @@ class CustomHFDatasetSaver:
              # "image_file": image_file_path_on_hub,
              "image_url": values_dic['image_url'] or "",
              "caption": values_dic['caption'] or "",
+             "vlm_caption": values_dic['vlm_caption'] or "",
+             "vlm_feedback": values_dic['vlm_feedback'] or "",
              "country": values_dic['country'] or "",
              "language": values_dic['language'] or "",
              "category": values_dic['category'] or "",
@@ -227,7 +259,7 @@ class CustomHFDatasetSaver:
              "category_4_concepts": values_dic.get('category_4_concepts') or [""],
              "category_5_concepts": values_dic.get('category_5_concepts') or [""],
              "timestamp": current_timestamp,
-             "username": values_dic['username'] or "",
+             "username": user_info['email'] if user_info['success'] else "",
              "password": values_dic['password'] or "",
              "id": values_dic['id'],
              "excluded": False if values_dic.get('excluded') is None else bool(values_dic.get('excluded')),
logic/handlers.py CHANGED
@@ -4,6 +4,7 @@ import io
  import PIL
  import requests
  from typing import Literal
+ from logic.supabase_client import auth_handler
 
  from datasets import load_dataset, concatenate_datasets, Image
  from data.lang2eng_map import lang2eng_mapping
@@ -12,6 +13,7 @@ import gradio as gr
  import bcrypt
  from config.settings import HF_API_TOKEN
  from huggingface_hub import snapshot_download
+ from logic.vlm import vlm_manager
  # from .blur import blur_faces, detect_faces
  from retinaface import RetinaFace
  from gradio_modal import Modal
@@ -71,13 +73,13 @@ def clear_data(message: Literal["submit", "remove"] | None = None):
          gr.Info("If you logged in, you will soon see it at the bottom of the page, where you can edit it or delete it", title="Thank you for submitting your data! 🎉", duration=5)
      elif message == "remove":
          gr.Info("", title="Your data has been deleted! 🗑️", duration=5)
-     return (None, None, None, None, None, gr.update(value=None),
+     return (None, None, None, gr.update(value=None), gr.update(value=None, visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True), None, None, gr.update(value=None),
              gr.update(value=[]), gr.update(value=[]), gr.update(value=[]),
              gr.update(value=[]), gr.update(value=[]))
 
 
  def exit():
-     return (None, None, None, gr.Dataset(samples=[]), gr.Markdown("**Loading your data, please wait ...**"),
+     return (None, None, None, gr.update(value=None), gr.update(value=None, visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True), gr.Dataset(samples=[]), gr.Markdown("**Loading your data, please wait ...**"),
              gr.update(value=None), gr.update(value=None), [None, None, "", ""], gr.update(value=None),
              gr.update(value=None), gr.update(value=None),
              gr.update(value=None), gr.update(value=None), gr.update(value=None),
@@ -87,9 +89,8 @@ def exit():
  def validate_metadata(country, language):
      # Perform your validation logic here
      if country is None or language is None:
-         return gr.Button("Proceed", interactive=False)
-
-     return gr.Button("Proceed", interactive=True)
+         return gr.update(interactive=False)
+     return gr.update(interactive=True)
 
 
  def validate_inputs(image, ori_img, concept): # is_blurred
@@ -129,6 +130,30 @@ def validate_inputs(image, ori_img, concept): # is_blurred
 
      return gr.Button("Submit", variant="primary", interactive=True), result_image, ori_img # is_blurred
 
+ def generate_vlm_caption(image, model_name="SmolVLM-500M"): # processor, model
+     """
+     Generate a caption for the given image using a Vision-Language Model.
+     Uses the global VLMManager for efficient model loading and caching.
+     """
+     if image is None:
+         gr.Warning("⚠️ Please upload an image first.", duration=5)
+         return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
+
+     try:
+         # Use the global VLMManager to load/get the model
+         vlm_manager.load_model(model_name)
+         caption = vlm_manager.generate_caption(image)
+     except Exception as e:
+         print(f"Error generating caption: {e}. Cleaning up memory and try again.")
+         gr.Warning(f"⚠️ Error generating caption: {e} due to memory issues. Please try again.", duration=5)
+         # vlm_manager.cleanup_memory()
+         return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
+     finally: # For now, let's cleanup memory after each generation
+         vlm_manager.cleanup_memory()
+
+     # print(caption)
+
+     return caption, gr.update(visible=True), gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)
 
  def count_words(caption, language):
      match language:
@@ -152,8 +177,14 @@ def add_prefix(example, column_name, prefix):
      example[column_name] = (f"{prefix}/" + example[column_name])
      return example
 
- def update_user_data(username, password, country, language_choice, HF_DATASET_NAME, local_ds_directory_path):
+ def update_user_data(client, country, language_choice, HF_DATASET_NAME, local_ds_directory_path):
+     user_info = auth_handler.is_logged_in(client)
+     print(f"User info: {user_info}")
+     if not user_info['success']:
+         print("User is not logged in or session expired.")
+         return gr.Dataset(samples=[]), None, None
 
+     username = user_info['email']
      datasets_list = []
      # Try loading local dataset
      try:
@@ -191,18 +222,19 @@
      # Handle all empty
      if not datasets_list:
          if username: # User is logged in but has no data
-             return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>")
+             return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>"), None
          else: # No user logged in
-             return gr.Dataset(samples=[]), gr.Markdown("")
+             return gr.Dataset(samples=[]), gr.Markdown(""), None
 
      dataset = concatenate_datasets(datasets_list)
      # TODO: we should link username with password and language and country, otherwise there will be an error when loading with different language and clicking on the example
-     if username and password:
-         user_dataset = dataset.filter(lambda x: x['username'] == username and is_password_correct(x['password'], password))
+     if username:
+         user_dataset = dataset.filter(lambda x: x['username'] == username)
          user_dataset = user_dataset.sort('timestamp', reverse=True)
          # Show only unique entries (most recent)
         user_ids = set()
         samples = []
+        vlm_captions = dict()
         for d in user_dataset:
             if d['id'] in user_ids:
                 continue
@@ -229,6 +261,10 @@
                 d['image_file'], d['image_url'], d['caption'] or "", d['country'],
                 d['language'], d['category'], d['concept'], additional_concepts_by_category, d['id']] # d['is_blurred']
             )
+
+            if 'vlm_caption' in d:
+                vlm_captions[d['id']] = d.get('vlm_caption', "")
+
         # return gr.Dataset(samples=samples), None
         # ───────────────────────────────────────────────────
         # Clean up the “Additional Concepts” column (index 7)
@@ -255,10 +291,14 @@
             row_copy[7] = ", ".join(vals)
             cleaned.append(row_copy)
 
-        return gr.Dataset(samples=cleaned), None
+        # check if vlm_captions is an empty dictionary
+        if not vlm_captions:
+            vlm_captions = None
+
+        return gr.Dataset(samples=cleaned), None, vlm_captions
     else:
         # TODO: should we show the entire dataset instead? What about "other data" tab?
-        return gr.Dataset(samples=[]), None
+        return gr.Dataset(samples=[]), None, None
 
 
  def update_language(local_storage, metadata_dict, concepts_dict):
@@ -357,7 +397,7 @@ def update_intro_language(selected_country, selected_language, intro_markdown, m
      return gr.Markdown(INTRO_TEXT)
 
 
- def handle_click_example(user_examples, concepts_dict):
+ def handle_click_example(user_examples, vlm_captions, concepts_dict):
      # print("handle_click_example")
      # print(user_examples)
      # ex = [item for item in user_examples]
@@ -365,7 +405,6 @@ def handle_click_example(user_examples, concepts_dict):
      # 1) Turn the flat string in slot 7 back into a list-of-lists
      ex = list(user_examples)
      raw_ac = ex[7] if len(ex) > 7 else ""
-
      country_btn = ex[3]
      language_btn = ex[4]
      concepts = concepts_dict[country_btn][language_btn]
@@ -441,7 +480,13 @@
      # dropdown_values.append(None)
 
      # Need to return values for each category dropdown
-     return [image_inp, image_url_inp, long_caption_inp, exampleid_btn, category_btn, concept_btn] + additional_concepts_by_category + [True]
+
+     vlm_caption = None
+     if vlm_captions:
+         if exampleid_btn in vlm_captions:
+             vlm_caption = vlm_captions[exampleid_btn]
+
+     return [image_inp, image_url_inp, long_caption_inp, exampleid_btn, category_btn, concept_btn] + additional_concepts_by_category + [True] + [vlm_caption] # loading_example flag + vlm_caption
      # return [
      #     image_inp,
      #     image_url_inp,
@@ -535,8 +580,8 @@ def blur_selected_faces(image, blur_faces_ids, faces_info, face_img, faces_count
      parsed_faces_ids = [f"face_{val.split(':')[-1].strip()}" for val in parsed_faces_ids]
 
      # Base blur amount and bounds
-     MIN_BLUR = 31   # Minimum blur amount (must be odd)
-     MAX_BLUR = 131  # Maximum blur amount (must be odd)
+     MIN_BLUR = 131  # Minimum blur amount (must be odd)
+     MAX_BLUR = 351  # Maximum blur amount (must be odd)
 
      blurring_start = time.time()
      # Process each face
@@ -688,4 +733,28 @@ def check_exclude_fn(image):
 
  def has_user_json(username, country, language_choice, local_ds_directory_path):
      """Check if JSON files exist for username pattern."""
-     return bool(glob.glob(os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username, "**", "*.json"), recursive=True))
+     return bool(glob.glob(os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username, "**", "*.json"), recursive=True))
+
+ def submit_button_clicked(vlm_output):
+
+     if vlm_output is None or vlm_output == '':
+         return Modal(visible=True), Modal(visible=False)
+     else:
+         return Modal(visible=False), Modal(visible=True)
+ # def submit_button_clicked(vlm_output, save_fn, data_outputs):
+ #     if vlm_output is None:
+ #         return Modal(visible=True)
+ #     else:
+ #         try:
+ #             save_fn(list(data_outputs.values()))
+ #         except Exception as e:
+ #             gr.Error(f"⚠️ Error saving data: {e}")
+
+ #     try:
+ #         image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, exampleid_btn, category_btn, concept_btn, \
+ #         category_concept_dropdowns0, category_concept_dropdowns1, category_concept_dropdowns2, category_concept_dropdowns3, \
+ #         category_concept_dropdowns4 = clear_data("submit")
+ #     except Exception as e:
+ #         gr.Error(f"⚠️ Error clearing data: {e}")
+
+ #     return Modal(visible=False)
logic/supabase_client.py ADDED
@@ -0,0 +1,193 @@
+ import gradio as gr
+ from supabase import create_client, Client
+ import os
+ from config.settings import SUPABASE_URL, SUPABASE_KEY, REDIRECT_TO_URL
+ import traceback
+ from supabase.lib.client_options import ClientOptions
+
+
+ # --- Supabase Authentication Class ---
+
+ class SupabaseAuth:
+     """A class to handle Supabase authentication logic."""
+     def __init__(self, url: str, key: str):
+         self.url = url
+         self.key = key
+         try:
+             self.client: Client = create_client(url, key)
+         except Exception as e:
+             print(f"Error creating Supabase client: {e}")
+             self.client = None
+
+     def login(self, email: str, password: str):
+         """
+         Attempts to log in a user and returns a user-specific client.
+         """
+         if not self.client:
+             return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
+         try:
+             response = self.client.auth.sign_in_with_password({"email": email, "password": password})
+             user_session = response.session
+
+             # Create a new, authenticated client for this user
+             authenticated_client = create_client(
+                 self.url,
+                 self.key,
+                 # options={"headers": {"Authorization": f"Bearer {user_session.access_token}"}}
+                 options=ClientOptions(
+                     headers={"Authorization": f"Bearer {user_session.access_token}"},
+                 )
+             )
+             authenticated_client.auth.set_session(user_session.access_token, user_session.refresh_token)
+
+             session_data = {
+                 "refresh_token": user_session.refresh_token,
+                 "user_email": user_session.user.email,
+                 "client": authenticated_client
+             }
+             return {'success': True, 'data': session_data, 'message': f"Welcome, {user_session.user.email}!"}
+         except Exception as e:
+             # print(f"Error logging in: {e}")
+             # traceback.print_exc()
+             # Handle specific error messages for better user feedback
+             return {'success': False, 'data': None, 'message': f"Login failed: {e}"}
+
+     def sign_up(self, email: str, password: str):
+         """Signs up a new user."""
+         if not self.client:
+             return {'success': False, 'message': "Supabase client not initialized."}
+         try:
+             # Supabase sign_up returns a session if email confirmation is disabled,
+             # or just a user object if it's enabled. We'll just return a success message.
+             self.client.auth.sign_up({
+                 "email": email,
+                 "password": password,
+             })
+             return {'success': True, 'message': 'Sign up successful! You can login now.'}
+         except Exception as e:
+             return {'success': False, 'message': f"Sign up failed: {e}"}
+
+     def restore_session(self, refresh_token: str):
+         """
+         Attempts to restore a session using a refresh token.
+         """
+         if not self.client:
+             return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
+         try:
+             response = self.client.auth.refresh_session(refresh_token)
+             user_session = response.session
+
+             authenticated_client = create_client(
+                 self.url,
+                 self.key,
+                 options=ClientOptions(
+                     headers={"Authorization": f"Bearer {user_session.access_token}"},
+                 )
+             )
+             authenticated_client.auth.set_session(user_session.access_token, user_session.refresh_token)
+
+             session_data = {
+                 "refresh_token": user_session.refresh_token,
+                 "user_email": user_session.user.email,
+                 "client": authenticated_client
+             }
+             print("Session restored successfully:", session_data)
+             return {'success': True, 'data': session_data, 'message': f"Welcome, {user_session.user.email}!"}
+         except Exception as e:
+             print("failed to restore session:", e)
+             return {'success': False, 'data': None, 'message': f"Failed to restore session: {e}"}
+
+     def logout(self, user_client: Client):
+         """Signs out the user from Supabase, invalidating the token."""
+         if not user_client:
+             return {'success': False, 'message': 'No user client provided to log out.'}
+         try:
+             user_client.auth.sign_out()
+             return {'success': True, 'message': 'Successfully signed out from Supabase.'}
+         except Exception as e:
+             # It's often safe to ignore errors here (e.g., if token already expired)
+             # but we'll log it for debugging.
+             print(f"Error signing out from Supabase: {e}")
+             return {'success': False, 'message': f'Error signing out: {e}'}
+
+     def change_password(self, user_client: Client, new_password: str):
+         """Changes the user's password."""
+         if not user_client:
+             return {'success': False, 'message': 'No user client provided to change password.'}
+         try:
+             user_client.auth.update_user({"password": new_password})
+             return {'success': True, 'message': 'Password changed successfully.'}
+         except Exception as e:
+             return {'success': False, 'message': f'Error changing password: {e}'}
+
+     def is_logged_in(self, user_client: Client):
+         """Checks if a user is currently authenticated and returns their email."""
+         print("Checking if user is logged in...", user_client)
+         if not user_client:
+             return {'success': False, 'email': None, 'message': 'No user client provided.'}
+         try:
+             user_response = user_client.auth.get_user()
+             user = user_response.user
+             if user:
+                 return {'success': True, 'email': user.email, 'message': f'Logged in as: {user.email}'}
+             else:
+                 return {'success': False, 'email': None, 'message': 'User is not logged in.'}
+         except Exception as e:
+             # This might happen if the token has expired and can't be refreshed.
+             return {'success': False, 'email': None, 'message': f'Authentication check failed: {e}'}
+
+     def reset_password_for_email(self, email: str):
+         """
+         Sends a password reset email to the specified address.
+         """
+         if not self.client:
+             return {'success': False, 'message': "Supabase client not initialized."}
+         try:
+             self.client.auth.reset_password_for_email(
+                 email,
+                 {
+                     "redirect_to": str(REDIRECT_TO_URL),
+                 }
+             )
+             return {'success': True, 'message': "Password reset email sent. Check your inbox!"}
+         except Exception as e:
+             return {'success': False, 'message': f"Failed to send reset email: {e}"}
+
+     def retrieve_session_from_tokens(self, access_token: str, refresh_token: str):
+         """
+         Retrieves a session from an access token and refresh token.
+         This is typically used after a password recovery link is clicked.
+         """
+         if not self.client:
+             return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
+         try:
+             # Set the session on the main client to verify tokens and get user info
+             self.client.auth.set_session(access_token, refresh_token)
+             user_response = self.client.auth.get_user()
+             user = user_response.user
+
+             if not user:
+                 return {'success': False, 'data': None, 'message': "Could not retrieve user from tokens."}
+
+             # Create a new, authenticated client for this user, similar to login
+             authenticated_client = create_client(
+                 self.url,
+                 self.key,
+                 options=ClientOptions(
+                     headers={"Authorization": f"Bearer {access_token}"},
+                 )
+             )
+             authenticated_client.auth.set_session(access_token, refresh_token)
+
+             session_data = {
+                 "refresh_token": refresh_token,
+                 "user_email": user.email,
+                 "client": authenticated_client
+             }
+             return {'success': True, 'data': session_data, 'message': f"Welcome, {user.email}!"}
+         except Exception as e:
+             return {'success': False, 'data': None, 'message': f"Failed to retrieve session from tokens: {e}"}
+
+
+ auth_handler = SupabaseAuth(SUPABASE_URL, SUPABASE_KEY)
+
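A short usage sketch of the new auth handler, mirroring how logic/data_utils.py and logic/handlers.py call it; the e-mail and password are placeholders:

from logic.supabase_client import auth_handler

result = auth_handler.login("user@example.com", "a-placeholder-password")
if result['success']:
    client = result['data']['client']          # per-user Supabase client, kept in Gradio state
    print(auth_handler.is_logged_in(client))   # {'success': True, 'email': 'user@example.com', ...}
    auth_handler.logout(client)
else:
    print(result['message'])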
logic/vlm.py ADDED
@@ -0,0 +1,440 @@
+ import torch, torchvision.transforms as T
+ from torchvision.transforms.functional import InterpolationMode
+ from PIL import Image
+ from transformers import TorchAoConfig, Qwen2_5_VLForConditionalGeneration, Gemma3ForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, AutoModel
+ from qwen_vl_utils import process_vision_info
+ import gc
+ # from transformers.image_utils import load_image
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ class VLMManager:
+     """
+     A manager class for Vision-Language Models that handles model loading,
+     caching, and dynamic switching between different models.
+     """
+
+     def __init__(self, default_model: str = "Gemma3-4B"):
+         """
+         Initialize the VLM Manager with a default model.
+
+         Args:
+             default_model (str): The default model to load initially.
+         """
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.current_model_name = None
+         self.processor = None
+         self.tokenizer = None  # Initialize tokenizer attribute
+         self.model = None
+
+         self.system_message = """
+         You are an expert cultural-aware image-analysis assistant. For every image:
+         1. Output exactly 40 words in total.
+         2. Use a single paragraph (no lists or bullet points).
+         3. Describe Who (appearance/emotion), What (action), and Where (setting).
+         4. Do NOT include opinions or speculations.
+         5. If you go over 40 words, shorten or remove non-essential details.
+         """
+
+         self.user_prompt = """
+         Given this image, please provide an image description of around 40 words with extensive and detailed visual information.
+
+         Descriptions must be objective: focus on how you would describe the image to someone who can't see it, without your own opinions/speculations.
+
+         The text needs to include the main concept and describe the content of the image in detail by including:
+         - Who?: The visual appearance and observable emotions (e.g., "is smiling") of persons and animals.
+         - What?: The actions performed in the image.
+         - Where?: The setting of the image, including the size, color, and relationships between objects.
+         """
+
+         # Load the default model
+         self.load_model(default_model)
+
+     def load_model(self, model_name: str):
+         """
+         Load a VLM model. If the model is already loaded, return the cached version.
+
+         Args:
+             model_name (str): The name of the model to load.
+         """
+         # If the requested model is already loaded, no need to reload
+         if self.current_model_name == model_name and self.model is not None:
+             print(f"Model {model_name} is already loaded, using cached version.")
+             if self.current_model_name == "InternVL3_5-8B":
+                 return self.tokenizer, self.model
+             else:
+                 return self.processor, self.model
+
+         print(f"Loading model: {model_name}")
+
+         # Clear current model from memory if exists
+         if self.model is not None:
+             del self.model
+             self.model = None
+             if self.current_model_name == "InternVL3_5-8B":
+                 if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+                     del self.tokenizer
+                     self.tokenizer = None
+             else:
+                 if hasattr(self, 'processor') and self.processor is not None:
+                     del self.processor
+                     self.processor = None
+             # Force garbage collection and clear CUDA cache
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 torch.cuda.synchronize()  # Wait for all operations to complete
+
+         # Load the new model
+         if model_name == "SmolVLM-500M":
+             self.processor, self.model = self._load_smolvlm_model("HuggingFaceTB/SmolVLM-500M-Instruct")
+         elif model_name == "Qwen2.5-VL-7B":
+             self.processor, self.model = self._load_qwen25_model("Qwen/Qwen2.5-VL-7B-Instruct")
+         elif model_name == "InternVL3_5-8B":
+             self.tokenizer, self.model = self._load_internvl35_model("OpenGVLab/InternVL3_5-8B-Instruct")
+         elif model_name == "Gemma3-4B":
+             self.processor, self.model = self._load_gemma3_model("google/gemma-3-4b-it")
+         else:
+             raise ValueError(f"Model {model_name} is not supported or not available.")
+
+         self.current_model_name = model_name
+         print(f"Successfully loaded model: {model_name}")
+
+     def generate_caption(self, image):
+         """
+         Generate a caption for the given image using the loaded model.
+
+         Args:
+             image: The image to generate a caption for.
+         """
+         if self.current_model_name == "SmolVLM-500M":
+             return self._inference_smolvlm_model(image)
+         elif self.current_model_name == "Qwen2.5-VL-7B":
+             return self._inference_qwen25_model(image)
+         elif self.current_model_name == "InternVL3_5-8B":
+             return self._inference_internvl35_model(image)
+         elif self.current_model_name == "Gemma3-4B":
+             return self._inference_gemma3_model(image)
+         else:
+             raise ValueError(f"Model {self.current_model_name} is not supported or not available.")
+
+     def get_current_model(self):
+         """
+         Get the currently loaded model and processor.
+
+         Returns:
+             tuple: A tuple containing (processor, model, model_name).
+         """
+         return self.processor, self.model, self.current_model_name
+
+     def cleanup_memory(self):
+         """
+         Explicit memory cleanup method that can be called to free GPU memory.
+         """
+         if self.model is not None:
+             del self.model
+             self.model = None
+         if hasattr(self, 'processor') and self.processor is not None:
+             del self.processor
+             self.processor = None
+         if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+             del self.tokenizer
+             self.tokenizer = None
+
+         self.current_model_name = None
+
+         # Force cleanup
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.synchronize()
+
+         print("Memory cleanup completed.")
+
+     #########################################################
+     ## Load functions
+
+     def _load_smolvlm_model(self, model_name):
+         """Load SmolVLM model."""
+         processor = AutoProcessor.from_pretrained(model_name)
+         model = AutoModelForVision2Seq.from_pretrained(
+             model_name,
+             _attn_implementation="eager"
+         ).to(self.device)
+         model.eval()
+         return processor, model
+
+     def _load_qwen25_model(self, model_name):
+         """Load Qwen2.5-VL model."""
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             model_name, torch_dtype="auto", device_map="auto"
+         )
+
+         # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+         # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         #     "Qwen/Qwen2.5-VL-7B-Instruct",
+         #     torch_dtype=torch.bfloat16,
+         #     attn_implementation="flash_attention_2",
+         #     device_map="auto",
+         # )
+
+         processor = AutoProcessor.from_pretrained(model_name)
+         model.eval()
+         return processor, model
+
+     def _load_internvl35_model(self, model_name):
+         """Load InternVL3.5 model."""
+         # Load tokenizer (InternVL uses tokenizer instead of processor for text)
+         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+         # Load the model using AutoModel
+         model = AutoModel.from_pretrained(
+             model_name,
+             torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
+             low_cpu_mem_usage=True,
+             use_flash_attn=False,  # set to False if there is a CUDA mismatch
+             trust_remote_code=True,
+             device_map="auto"
+         )
+
+         model.eval()
+
+         # Return tokenizer as processor for consistency with the interface
+         return tokenizer, model
+
+     def _load_gemma3_model(self, model_name):
+         """Load Gemma3 model."""
+         quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
+         model = Gemma3ForConditionalGeneration.from_pretrained(
+             model_name,
+             device_map="auto",
+             quantization_config=quantization_config
+         )
+         processor = AutoProcessor.from_pretrained(model_name)
+         model.eval()
+         return processor, model
+
+     #########################################################
+     ## Inference functions
+     def check_processor_and_model(self):
+         if self.processor is None or self.model is None:
+             raise ValueError("Processor and model must be loaded before generating a caption.")
+
+     def _inference_qwen25_model(self, image):
+         """Inference Qwen2.5-VL model."""
+         self.check_processor_and_model()
+         messages = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": self.system_message}]
+             },
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "image",
+                         "image": Image.fromarray(image),
+                     },
+                     {"type": "text", "text": self.user_prompt},
+                 ],
+             }
+         ]
+
+         # Preparation for inference
+         text = self.processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = self.processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = inputs.to(self.model.device)
+
+         # Inference: Generation of the output
+         generated_ids = self.model.generate(**inputs, max_new_tokens=128)
+         generated_ids_trimmed = [
+             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+         caption = self.processor.batch_decode(
+             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+         )[0]
+
+         # Clean up tensors to free GPU memory
+         del inputs, generated_ids, generated_ids_trimmed
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         return caption
+
+     def _inference_gemma3_model(self, image):
+         """Inference Gemma3 model."""
+         self.check_processor_and_model()
+         messages = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": self.system_message}]
+             },
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": Image.fromarray(image)},
+                     {"type": "text", "text": self.user_prompt}
+                 ]
+             }
+         ]
+
+         inputs = self.processor.apply_chat_template(
+             messages, add_generation_prompt=True, tokenize=True,
+             return_dict=True, return_tensors="pt"
+         ).to(self.model.device, dtype=torch.bfloat16)
+
+         input_len = inputs["input_ids"].shape[-1]
+
+         with torch.inference_mode():
+             generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
+             generation = generation[0][input_len:]
+
+         caption = self.processor.decode(generation, skip_special_tokens=True)
+
+         # Clean up tensors to free GPU memory
+         del inputs, generation
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         return caption
+
+     def _inference_smolvlm_model(self, image):
+         self.check_processor_and_model()
+         messages = [
+             {
+                 "role": "system",
+                 "content": self.system_message
+             },
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image"},
+                     {"type": "text", "text": self.user_prompt}
+                 ]
+             }
+         ]
+
+         # Prepare inputs
+         prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+         inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+         inputs = inputs.to(self.model.device)
+
+         # Generate outputs
+         gen_kwargs = {
+             "max_new_tokens": 200,  # plenty for ~40 words
+             # "early_stopping": True,        # stop at first EOS
+             # "no_repeat_ngram_size": 3,     # discourage loops
+             # "length_penalty": 0.8,         # slightly favor brevity
+             # "eos_token_id": processor.tokenizer.eos_token_id,
+             # "pad_token_id": processor.tokenizer.eos_token_id,
+         }
+         generated_ids = self.model.generate(**inputs, **gen_kwargs)  # max_new_tokens=500)
+         generated_texts = self.processor.batch_decode(
+             generated_ids,
+             skip_special_tokens=True,
+         )[0]
+
+         # Extract only what the assistant said
+         if "Assistant:" in generated_texts:
+             caption = generated_texts.split("Assistant:", 1)[1].strip()
+         else:
+             caption = generated_texts.strip()
+
+         # Clean up tensors to free GPU memory
+         del inputs, generated_ids
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         return caption
+
+     def _inference_internvl35_model(self, image):
+         if self.tokenizer is None:
+             raise ValueError("Tokenizer must be loaded before generating a caption for InternVL3.5.")
+         # image can be numpy (H,W,3) or PIL.Image
+         if hasattr(image, "shape"):  # numpy array
+             pil_image = Image.fromarray(image.astype("uint8"), mode="RGB")
+         else:
+             pil_image = image
+
+         pixel_values = self._image_to_pixel_values(pil_image, size=448, max_num=12)
+         pixel_values = pixel_values.to(dtype=torch.bfloat16, device=self.model.device)
+
+         # Format question with image token (matches official docs)
+         question = "<image>\n" + self.user_prompt
+
+         # Generation config matching official examples
+         gen_cfg = dict(
+             max_new_tokens=128,
+             do_sample=False,
+             temperature=0.0,
+             # Optional: add other parameters from docs
+             # top_p=0.9,
+             # repetition_penalty=1.1
+         )
+
+         # Use model's chat method (official approach)
+         response = self.model.chat(self.tokenizer, pixel_values, question, gen_cfg)
+
+         # Clean up tensors to free GPU memory
+         del pixel_values
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         return response.strip()
+
+     def _image_to_pixel_values(self, img, size=448, max_num=12):
+         transform = self._build_transform(size)
+         tiles = self._dynamic_preprocess(img, image_size=size, max_num=max_num, use_thumbnail=True)
+         pixel_values = torch.stack([transform(t) for t in tiles])
+         return pixel_values
+
+     def _dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
+         # same logic as the model card: split into tiles based on aspect ratio
+         w, h = image.size
+         aspect = w / h
+         targets = sorted({(i, j) for n in range(min_num, max_num+1)
+                           for i in range(1, n+1) for j in range(1, n+1)
+                           if i*j <= max_num and i*j >= min_num},
+                          key=lambda x: x[0]*x[1])
+
+         # pick closest ratio
+         best = min(targets, key=lambda r: abs(aspect - r[0]/r[1]))
+         tw, th = image_size * best[0], image_size * best[1]
+         resized = image.resize((tw, th))
+
+         tiles = []
+         for i in range(best[0] * best[1]):
+             box = ((i % (tw // image_size)) * image_size,
+                    (i // (tw // image_size)) * image_size,
+                    ((i % (tw // image_size)) + 1) * image_size,
+                    ((i // (tw // image_size)) + 1) * image_size)
+             tiles.append(resized.crop(box))
+
+         if use_thumbnail and len(tiles) != 1:
+             tiles.append(image.resize((image_size, image_size)))
+         return tiles
+
+     def _build_transform(self, input_size=448):
+         return T.Compose([
+             T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+             T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+             T.ToTensor(),
+             T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+         ])
+
+
+ # Global VLM Manager instance
+ vlm_manager = VLMManager()
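A usage sketch of the manager as generate_vlm_caption in logic/handlers.py drives it; the zero-filled array stands in for the photo a contributor uploads:

import numpy as np
from logic.vlm import vlm_manager   # note: importing this module eagerly loads the default Gemma3-4B model

image = np.zeros((448, 448, 3), dtype=np.uint8)  # placeholder for the uploaded image array
try:
    vlm_manager.load_model("SmolVLM-500M")        # switch to the lighter model used by the handler
    caption = vlm_manager.generate_caption(image)
    print(caption)
finally:
    vlm_manager.cleanup_memory()                  # the handler frees GPU memory after every call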
requirements.txt CHANGED
@@ -1,4 +1,5 @@
  absl-py==2.2.2
+ accelerate==1.9.0
  aiofiles==23.2.1
  aiohappyeyeballs==2.6.1
  aiohttp==3.11.16
@@ -8,15 +9,25 @@ anyio==4.9.0
  astunparse==1.6.3
  async-timeout==5.0.1
  attrs==25.3.0
+ av==15.1.0
  bcrypt==4.3.0
  beautifulsoup4==4.13.3
+ bitsandbytes==0.46.1
+ blis==1.3.0
+ catalogue==2.0.10
  certifi==2025.1.31
  charset-normalizer==3.4.1
  click==8.1.8
+ cloudpathlib==0.21.1
+ confection==0.1.5
  cycler==0.12.1
+ cymem==2.0.11
  datasets==3.5.0
+ decord==0.6.0
  deep-translator==1.11.4
+ deplacy==2.1.0
  dill==0.3.8
+ einops==0.8.1
  et_xmlfile==2.0.0
  exceptiongroup==1.2.2
  fastapi==0.115.12
@@ -41,16 +52,36 @@ huggingface-hub==0.30.1
  idna==3.10
  Jinja2==3.1.6
  keras==3.9.2
+ langcodes==3.5.0
+ language_data==1.3.0
  libclang==18.1.1
+ marisa-trie==1.2.1
  Markdown==3.8
  markdown-it-py==3.0.0
  MarkupSafe==3.0.2
  mdurl==0.1.2
  ml_dtypes==0.5.1
+ mpmath==1.3.0
  multidict==6.3.2
  multiprocess==0.70.16
+ murmurhash==1.0.13
  namex==0.0.8
+ networkx==3.4.2
  numpy==2.1.3
+ nvidia-cublas-cu12==12.6.4.1
+ nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cufft-cu12==11.3.0.4
+ nvidia-cufile-cu12==1.11.1.6
+ nvidia-curand-cu12==10.3.7.77
+ nvidia-cusolver-cu12==11.7.1.2
+ nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cusparselt-cu12==0.6.3
+ nvidia-nccl-cu12==2.26.2
+ nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nvtx-cu12==12.6.77
  opencv-python==4.11.0.86
  openpyxl==3.1.5
  opt_einsum==3.4.0
@@ -59,54 +90,78 @@ orjson==3.10.16
  packaging==24.2
  pandas==2.2.3
  pillow==11.1.0
+ pillow_heif==1.0.0
+ preshed==3.0.10
  propcache==0.3.1
  protobuf==5.29.4
+ psutil==7.0.0
  pyarrow==19.0.1
- pydantic==2.11.2
- pydantic_core==2.33.1
+ pydantic
+ pydantic_core
  pydub==0.25.1
  Pygments==2.19.1
  PySocks==1.7.1
+ pythainlp==5.1.2
  python-dateutil==2.9.0.post0
  python-dotenv==1.1.0
  python-multipart==0.0.20
  pytz==2025.2
+ pyuca==1.2
  PyYAML==6.0.2
+ qwen-vl-utils==0.0.8
+ regex==2024.11.6
  requests==2.32.3
  retina-face==0.0.17
  rich==14.0.0
  ruff==0.11.4
  safehttpx==0.1.6
+ safetensors==0.5.3
  semantic-version==2.10.0
  shellingham==1.5.4
  six==1.17.0
+ smart_open==7.3.0.post1
  sniffio==1.3.1
  soupsieve==2.6
+ spacy==3.8.7
+ spacy-legacy==3.0.12
+ spacy-loggers==1.0.5
+ spacy-thai==0.7.8
+ spacy-udpipe==1.0.0
+ srsly==2.5.1
  starlette==0.46.1
+ sympy==1.14.0
  tensorboard==2.19.0
  tensorboard-data-server==0.7.2
  tensorflow==2.19.0
  tensorflow-io-gcs-filesystem==0.37.1
  termcolor==3.0.1
  tf_keras==2.19.0
+ thinc==8.3.6
+ timm==1.0.19
+ tokenizers==0.21.2
  tomlkit==0.13.2
+ torch==2.7.1
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+ torchao==0.13.0
+ torchvision==0.22.1
  tqdm==4.67.1
+ transformers==4.53.3
+ triton==3.3.1
  typer==0.15.2
- typing-inspection==0.4.0
- typing_extensions==4.12.2
+ typing-inspection
+ typing_extensions
  tzdata==2025.2
+ ufal.udpipe==1.3.1.1
  urllib3==2.3.0
  uvicorn==0.34.0
+ wasabi==1.1.3
+ weasel==0.4.1
  websockets==15.0.1
  Werkzeug==3.1.3
  wrapt==1.17.2
  xxhash==3.5.0
  yarl==1.19.0
- spacy_udpipe==1.0.0
- pyuca==1.2
- pillow_heif==1.0.0
- spacy==3.8.7
- spacy-legacy==3.0.12
- spacy-loggers==1.0.5
- spacy_thai==0.7.8
- spacy-udpipe==1.0.0
+ supabase==2.18.1
+ supabase_auth==2.12.3
+ supabase_functions==0.10.1
+ # flash_attn==2.8.1
ui/layout.py CHANGED
@@ -1,5 +1,6 @@
  import gradio as gr
  import time
 
  from logic.data_utils import CustomHFDatasetSaver
  from data.lang2eng_map import lang2eng_mapping
@@ -12,6 +13,161 @@ from .selection_page import build_selection_page
  from .main_page import build_main_page
  from .main_page import sort_with_pyuca
 
  def get_key_by_value(dictionary, value):
      for key, val in dictionary.items():
          if val == value:
@@ -100,14 +256,27 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
          object-fit: contain; /* make sure the full image shows */
          height: 460px; /* set a fixed height */
      }
      """
      ############################################################################
      with gr.Blocks(css=custom_css) as ui:
          local_storage = gr.State([None, None, "", ""])
          loading_example = gr.State(False) # to check if the values are loaded from a user click on an example in
          # First page: selection
 
-         selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown = build_selection_page(metadata_dict)
 
          # Second page
          cmp_main_ui = build_main_page(concepts_dict, metadata_dict, local_storage)
@@ -144,8 +313,20 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
          modal_exclude_confirm = cmp_main_ui["modal_exclude_confirm"]
          cancel_exclude_btn = cmp_main_ui["cancel_exclude_btn"]
          confirm_exclude_btn = cmp_main_ui["confirm_exclude_btn"]
-
-
          ### Category button
          category_btn.change(
              fn=partial(load_concepts, concepts=concepts_dict),
@@ -214,7 +395,7 @@
          clear_btn.click(
              fn=clear_data,
              outputs=[
-                 image_inp, image_url_inp, long_caption_inp, exampleid_btn,
                  category_btn, concept_btn,
                  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
                  category_concept_dropdowns[3], category_concept_dropdowns[4]
@@ -280,12 +461,12 @@
          # Handle clicking on an example
          user_examples.click(
              fn=partial(handle_click_example, concepts_dict=concepts_dict),
-             inputs=[user_examples],
              outputs=[
                  image_inp, image_url_inp, long_caption_inp, exampleid_btn,
                  category_btn, concept_btn,
                  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
-                 category_concept_dropdowns[3], category_concept_dropdowns[4], loading_example
              ],
          )
@@ -295,6 +476,41 @@
 
          # ============================================ #
          # Submit Button Click events
 
          proceed_btn.click(
              fn=partial(switch_ui, flag=False),
@@ -313,8 +529,8 @@
              ]
          ).then(
              fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
-             inputs=[username_inp, password_inp, country_choice, language_choice],
-             outputs=[user_examples, loading_msg],
          )
 
 
@@ -322,7 +538,7 @@
          exit_btn.click(
              fn=exit,
              outputs=[
-                 image_inp, image_url_inp, long_caption_inp, user_examples, loading_msg,
                  username, password, local_storage, exampleid_btn, category_btn, concept_btn,
                  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
                  category_concept_dropdowns[3], category_concept_dropdowns[4]
@@ -368,7 +584,10 @@
              "excluded": gr.State(value=False),
              "concepts_dict": gr.State(value=concepts_dict),
              "country_lang_map": gr.State(value=lang2eng_mapping),
              # "is_blurred": is_blurred
          }
          # data_outputs = [image_inp, image_url_inp, long_caption_inp,
          #                 country_inp, language_inp, category_btn, concept_btn,
@@ -376,34 +595,56 @@
          hf_writer.setup(list(data_outputs.keys()), local_ds_folder = LOCAL_DS_DIRECTORY_PATH)
 
          # STEP 4: Chain save_data, then update_user_data, then re-enable button, hide modal, and clear
-         submit_btn.click(
-             hf_writer.save,
-             list(data_outputs.values()),
-             None,
-         ).success(
-             fn=partial(clear_data, "submit"),
-             outputs=[
-                 image_inp, image_url_inp, long_caption_inp, exampleid_btn,
-                 category_btn, concept_btn,
-                 category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
-                 category_concept_dropdowns[3], category_concept_dropdowns[4]
-             ],
-             # ).success(enable_submit,
-             #     None, [submit_btn]
-             # ).success(lambda: Modal(visible=False),
-             #     None, modal_saving
-             # ).success(lambda: Modal(visible=True),
-             #     None, modal_data_saved
-         ).success(
-             # set loading msg
-             lambda: gr.update(value="**Loading your data, please wait ...**"),
-             None, loading_msg
-         ).success(
-             fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
-             inputs=[username_inp, password_inp, country_choice, language_choice],
-             outputs=[user_examples, loading_msg]
-         )
-
          # ============================================ #
          # instructions button
          instruct_btn.click(lambda: Modal(visible=True), None, modal)
@@ -446,13 +687,13 @@
                  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
                  category_concept_dropdowns[3], category_concept_dropdowns[4],
                  timestamp_btn, username_inp, password_inp, exampleid_btn, gr.State(value=True),
-                 gr.State(value=concepts_dict), gr.State(value=lang2eng_mapping)
              ],
              outputs=None
          ).success(
              fn=partial(clear_data, "remove"),
-             outputs=[
-                 image_inp, image_url_inp, long_caption_inp, exampleid_btn,
                  category_btn, concept_btn,
                  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
                  category_concept_dropdowns[3], category_concept_dropdowns[4]
@@ -465,8 +706,32 @@
              outputs=loading_msg
          ).success(
              fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path=LOCAL_DS_DIRECTORY_PATH),
-             inputs=[username_inp, password_inp, country_choice, language_choice],
-             outputs=[user_examples, loading_msg]
          )
 
-     return ui
 
  import gradio as gr
  import time
+ from logic.supabase_client import auth_handler
 
  from logic.data_utils import CustomHFDatasetSaver
  from data.lang2eng_map import lang2eng_mapping
 
  from .main_page import build_main_page
  from .main_page import sort_with_pyuca
 
+ js_code = """
+ function() {
+     // Get the full URL with the fragment
+     const url = window.location.href;
+     const fragment = url.split('#')[1];
21
+
22
+ if (!fragment) {
23
+ return "";
24
+ }
25
+
26
+ // Parse the fragment into an object
27
+ const params = new URLSearchParams(fragment);
28
+ const access_token = params.get('access_token');
29
+ const refresh_token = params.get('refresh_token');
30
+
31
+ // Create a JSON string with the tokens
32
+ const tokens = JSON.stringify({
33
+ access_token: access_token,
34
+ refresh_token: refresh_token
35
+ });
36
+
37
+ // Return the JSON string to the Gradio output component
38
+ return tokens;
39
+ }
40
+ """
41
+
42
+ def login_user(email, password):
43
+ result = auth_handler.login(email, password)
44
+ if result['success']:
45
+ session_data = result['data']
46
+ persistent_data = {
47
+ "refresh_token": session_data['refresh_token'],
48
+ "user_email": session_data['user_email']
49
+ }
50
+ return session_data['client'], persistent_data, result['message']
51
+ else:
52
+ persistent_data = {
53
+ "refresh_token": "",
54
+ "user_email": ""
55
+ }
56
+ return None, persistent_data, result['message']
57
+
58
+ def login_user_recovery(session_data: str):
59
+ """
60
+ This function receives session data (tokens as a JSON string) from the frontend,
61
+ retrieves the session, and returns data in a format similar to login_user.
62
+ """
63
+ try:
64
+ import json
65
+ tokens = json.loads(session_data)
66
+ access_token = tokens.get("access_token")
67
+ refresh_token = tokens.get("refresh_token")
68
+
69
+ if not access_token or not refresh_token:
70
+ return None, gr.skip(), "Invalid session data provided."
71
+
72
+ result = auth_handler.retrieve_session_from_tokens(access_token, refresh_token)
73
+
74
+ if result['success']:
75
+ session_data_result = result['data']
76
+ persistent_data = {
77
+ "refresh_token": session_data_result['refresh_token'],
78
+ "user_email": session_data_result['user_email']
79
+ }
80
+ return session_data_result['client'], persistent_data, result['message']
81
+ else:
82
+ persistent_data = {
83
+ "refresh_token": "",
84
+ "user_email": ""
85
+ }
86
+ return None, persistent_data, result['message']
87
+
88
+ except Exception as e:
89
+ return None, gr.skip(), f"Failed to process recovery login: {e}"
90
+
91
+ def sign_up(email, password):
92
+ result = auth_handler.sign_up(email, password)
93
+ return result['message']
94
+
95
+ def reset_password(email):
96
+ result = auth_handler.reset_password_for_email(email)
97
+ return result['message']
98
+
99
+ def log_out(supabase_user_client, persistent_session):
100
+ """
101
+ Logs out the user and clears the session. If an error occurs, it still returns an empty persistent session (logging the user out).
102
+ """
103
+ persistent_session = {
104
+ "refresh_token": "",
105
+ "user_email": ""
106
+ }
107
+ if supabase_user_client:
108
+ result = auth_handler.logout(supabase_user_client)
109
+ if result['success']:
110
+ print("User logged out successfully.")
111
+ return persistent_session
112
+ else:
113
+ print(f"Error logging out: {result['message']}")
114
+ return persistent_session
115
+ else:
116
+ print("No user client provided to log out.")
117
+ return persistent_session
118
+
119
+ def restore_user_session(session_data, login_status=None):
120
+ print("Restoring user session with data:", session_data)
121
+ # default values if the user is not logged in
122
+ # or the session data is not valid
123
+ login_status_update = gr.update(value= login_status if login_status else "")
124
+ proceed_button_update = gr.update(value="Proceed as Anonymous User", interactive=True)
125
+ login_button_update = gr.update(visible=True)
126
+ sign_up_button_update = gr.update(visible=True)
127
+ reset_password_button_update = gr.update(visible=True)
128
+ logout_button_update = gr.update(visible=False)
129
+ change_password_field_update = gr.update(visible=False)
130
+ change_password_field_confirm_update = gr.update(visible=False)
131
+ change_password_button_update = gr.update(visible=False)
132
+ change_password_status_update = gr.update(value="")
133
+ persistent_data = {
134
+ "refresh_token": "",
135
+ "user_email": ""
136
+ }
137
+ if not session_data or not session_data.get('refresh_token', ''):
138
+ print("No session data found, proceeding as anonymous user.")
139
+ return None, persistent_data, login_status_update, proceed_button_update, login_button_update, sign_up_button_update, reset_password_button_update, logout_button_update, change_password_field_update, change_password_field_confirm_update, change_password_button_update, change_password_status_update
140
+
141
+ result = auth_handler.restore_session(session_data['refresh_token'])
142
+ if result['success']:
143
+ restored_session = result['data']
144
+ new_persistent_data = {
145
+ "refresh_token": restored_session['refresh_token'],
146
+ "user_email": restored_session['user_email']
147
+ }
148
+ login_status_update = gr.update(value=result['message'])
149
+ proceed_button_update = gr.update(value="Proceed", interactive=True)
150
+ login_button_update = gr.update(visible=False)
151
+ sign_up_button_update = gr.update(visible=False)
152
+ reset_password_button_update = gr.update(visible=False)
153
+ logout_button_update = gr.update(visible=True)
154
+ change_password_field_update = gr.update(visible=True)
155
+ change_password_field_confirm_update = gr.update(visible=True)
156
+ change_password_button_update = gr.update(visible=True)
157
+ return restored_session['client'], new_persistent_data, login_status_update, proceed_button_update, login_button_update, sign_up_button_update, reset_password_button_update, logout_button_update, change_password_field_update, change_password_field_confirm_update, change_password_button_update, change_password_status_update
158
+ else:
159
+ return None, persistent_data, login_status_update, proceed_button_update, login_button_update, sign_up_button_update, reset_password_button_update, logout_button_update, change_password_field_update, change_password_field_confirm_update, change_password_button_update, change_password_status_update
160
+
161
+ def change_password(supabase_user_client, new_password, confirm_password):
162
+ """
163
+ Changes the user's password.
164
+ """
165
+ if new_password != confirm_password:
166
+ return "Passwords do not match. Please try again."
167
+ result = auth_handler.change_password(supabase_user_client, new_password)
168
+ return result['message']
169
+
170
+
171
  def get_key_by_value(dictionary, value):
172
  for key, val in dictionary.items():
173
  if val == value:
 
256
  object-fit: contain; /* make sure the full image shows */
257
  height: 460px; /* set a fixed height */
258
  }
259
+ #vlm_output .input-container {
260
+ position: relative;
261
+ }
262
+ #vlm_output .input-container::before {
263
+ content: "";
264
+ position: absolute;
265
+ top: 0; left: 0; right: 0; bottom: 0;
266
+ z-index: 10; /* sits above the textarea */
267
+ background: transparent;
268
+ }
269
  """
270
  ############################################################################
271
  with gr.Blocks(css=custom_css) as ui:
272
+ supabase_user_client = gr.State(None)
273
+ persistent_session = gr.BrowserState(None)
274
+
275
  local_storage = gr.State([None, None, "", ""])
276
  loading_example = gr.State(False) # to check if the values are loaded from a user click on an example in
277
  # First page: selection
278
 
279
+ selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown, login_btn, sign_up_btn, reset_password_btn, login_status, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status = build_selection_page(metadata_dict)
280
 
281
  # Second page
282
  cmp_main_ui = build_main_page(concepts_dict, metadata_dict, local_storage)
 
313
  modal_exclude_confirm = cmp_main_ui["modal_exclude_confirm"]
314
  cancel_exclude_btn = cmp_main_ui["cancel_exclude_btn"]
315
  confirm_exclude_btn = cmp_main_ui["confirm_exclude_btn"]
316
+ vlm_output = cmp_main_ui["vlm_output"]
317
+ gen_button = cmp_main_ui["gen_button"]
318
+ vlm_feedback = cmp_main_ui["vlm_feedback"]
319
+ modal_vlm = cmp_main_ui["modal_vlm"]
320
+ vlm_no_btn = cmp_main_ui["vlm_no_btn"]
321
+ vlm_done_btn = cmp_main_ui["vlm_done_btn"]
322
+ submit_yes = cmp_main_ui["submit_yes"]
323
+ submit_no = cmp_main_ui["submit_no"]
324
+ modal_submit = cmp_main_ui["modal_submit"]
325
+ vlm_cancel_btn = cmp_main_ui["vlm_cancel_btn"]
326
+ vlm_model_dropdown = cmp_main_ui["vlm_model_dropdown"]
327
+
328
+ # dictionary to store all vlm_output by exampleid
329
+ vlm_captions = gr.State(None)
330
  ### Category button
331
  category_btn.change(
332
  fn=partial(load_concepts, concepts=concepts_dict),
 
395
  clear_btn.click(
396
  fn=clear_data,
397
  outputs=[
398
+ image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
399
  category_btn, concept_btn,
400
  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
401
  category_concept_dropdowns[3], category_concept_dropdowns[4]
 
461
  # Handle clicking on an example
462
  user_examples.click(
463
  fn=partial(handle_click_example, concepts_dict=concepts_dict),
464
+ inputs=[user_examples, vlm_captions],
465
  outputs=[
466
  image_inp, image_url_inp, long_caption_inp, exampleid_btn,
467
  category_btn, concept_btn,
468
  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
469
+ category_concept_dropdowns[3], category_concept_dropdowns[4], loading_example, vlm_output
470
  ],
471
  )
472
 
 
476
 
477
  # ============================================ #
478
  # Submit Button Click events
479
+ login_btn.click(
480
+ fn=login_user,
481
+ inputs=[username, password],
482
+ outputs=[supabase_user_client, persistent_session, login_status],
483
+ ).then(
484
+ fn=restore_user_session,
485
+ inputs=[persistent_session, login_status],
486
+ outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
487
+ )
488
+
489
+ sign_up_btn.click(
490
+ fn=sign_up,
491
+ inputs=[username, password],
492
+ outputs=[login_status],
493
+ )
494
+
495
+ logout_btn.click(
496
+ fn=log_out,
497
+ inputs=[supabase_user_client, persistent_session],
498
+ outputs=[persistent_session]
499
+ ).then(
500
+ fn=restore_user_session,
501
+ inputs=[persistent_session],
502
+ outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
503
+ )
504
+ change_password_btn.click(
505
+ fn=change_password,
506
+ inputs=[supabase_user_client, change_password_field, change_password_field_confirm],
507
+ outputs=[change_password_status]
508
+ )
509
+ reset_password_btn.click(
510
+ fn=reset_password,
511
+ inputs=[username],
512
+ outputs=[login_status]
513
+ )
514
 
515
  proceed_btn.click(
516
  fn=partial(switch_ui, flag=False),
 
529
  ]
530
  ).then(
531
  fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
532
+ inputs=[supabase_user_client, country_choice, language_choice],
533
+ outputs=[user_examples, loading_msg, vlm_captions],
534
  )
535
 
536
 
 
538
  exit_btn.click(
539
  fn=exit,
540
  outputs=[
541
+ image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, user_examples, loading_msg,
542
  username, password, local_storage, exampleid_btn, category_btn, concept_btn,
543
  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
544
  category_concept_dropdowns[3], category_concept_dropdowns[4]
 
584
  "excluded": gr.State(value=False),
585
  "concepts_dict": gr.State(value=concepts_dict),
586
  "country_lang_map": gr.State(value=lang2eng_mapping),
587
+ "client": supabase_user_client,
588
  # "is_blurred": is_blurred
589
+ "vlm_caption": vlm_output,
590
+ "vlm_feedback": vlm_feedback
591
  }
592
  # data_outputs = [image_inp, image_url_inp, long_caption_inp,
593
  # country_inp, language_inp, category_btn, concept_btn,
 
595
  hf_writer.setup(list(data_outputs.keys()), local_ds_folder = LOCAL_DS_DIRECTORY_PATH)
596
 
597
  # STEP 4: Chain save_data, then update_user_data, then re-enable button, hide modal, and clear
598
+ # submit_btn.click(lambda: Modal(visible=True), None, modal_vlm)
599
+ submit_btn.click(submit_button_clicked,
600
+ inputs=[vlm_output],
601
+ outputs=[modal_vlm, modal_submit])
602
+
603
+ # submit_btn.click(partial(submit_button_clicked, save_fn=hf_writer.save,
604
+ # data_outputs=data_outputs),
605
+ # inputs=[vlm_output],
606
+ # outputs=[modal_vlm, image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, exampleid_btn,
607
+ # category_btn, concept_btn, category_concept_dropdowns[0], category_concept_dropdowns[1],
608
+ # category_concept_dropdowns[2], category_concept_dropdowns[3], category_concept_dropdowns[4]])
609
+
610
+ def wire_submit_chain(button, modal_ui):
611
+ e = button.click(
612
+ fn=lambda: Modal(visible=False),
613
+ outputs=[modal_ui]
614
+ ).success(
615
+ hf_writer.save,
616
+ inputs = list(data_outputs.values()),
617
+ outputs = None,
618
+ ).success(
619
+ fn=partial(clear_data, "submit"),
620
+ outputs=[
621
+ image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
622
+ category_btn, concept_btn,
623
+ category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
624
+ category_concept_dropdowns[3], category_concept_dropdowns[4]
625
+ ],
626
+ # ).success(enable_submit,
627
+ # None, [submit_btn]
628
+ # ).success(lambda: Modal(visible=False),
629
+ # None, modal_saving
630
+ # ).success(lambda: Modal(visible=True),
631
+ # None, modal_data_saved
632
+ ).success(
633
+ # set loading msg
634
+ lambda: gr.update(value="**Loading your data, please wait ...**"),
635
+ None, loading_msg
636
+ ).success(
637
+ fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
638
+ inputs=[supabase_user_client, country_choice, language_choice],
639
+ outputs=[user_examples, loading_msg, vlm_captions]
640
+ )
641
+ return e
642
+
643
+ wire_submit_chain(vlm_done_btn, modal_vlm)
644
+ wire_submit_chain(vlm_no_btn, modal_vlm)
645
+ wire_submit_chain(submit_yes, modal_submit)
646
+ submit_no.click(lambda: Modal(visible=False), None, modal_submit)
647
+ vlm_cancel_btn.click(lambda: Modal(visible=False), None, modal_vlm)
648
  # ============================================ #
649
  # instructions button
650
  instruct_btn.click(lambda: Modal(visible=True), None, modal)
 
687
  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
688
  category_concept_dropdowns[3], category_concept_dropdowns[4],
689
  timestamp_btn, username_inp, password_inp, exampleid_btn, gr.State(value=True),
690
+ gr.State(value=concepts_dict), gr.State(value=lang2eng_mapping), vlm_output, vlm_feedback
691
  ],
692
  outputs=None
693
  ).success(
694
  fn=partial(clear_data, "remove"),
695
+ outputs=[
696
+ image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
697
  category_btn, concept_btn,
698
  category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
699
  category_concept_dropdowns[3], category_concept_dropdowns[4]
 
706
  outputs=loading_msg
707
  ).success(
708
  fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path=LOCAL_DS_DIRECTORY_PATH),
709
+ inputs=[supabase_user_client, country_choice, language_choice],
710
+ outputs=[user_examples, loading_msg, vlm_captions]
711
+ )
712
+ # ============================================= #
713
+ # VLM Gen button
714
+ # ============================================= #
715
+ gen_button.click(
716
+ fn=generate_vlm_caption, # processor=processor, model=model
717
+ inputs=[image_inp, vlm_model_dropdown],
718
+ outputs=[vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button]
719
+ )
720
+ # vlm_output.change(
721
+ # fn=lambda : gr.update(interactive=False) if vlm_output.value else gr.update(interactive=True),
722
+ # inputs=[],
723
+ # outputs=[gen_button]
724
+ # )
725
+
726
+ ui.load(
727
+ fn=login_user_recovery,
728
+ inputs=gr.Textbox(visible=False, value=""), # hidden textbox that receives the URL fragment tokens
729
+ outputs=[supabase_user_client, persistent_session, login_status],
730
+ js=js_code
731
+ ).then(
732
+ fn=restore_user_session,
733
+ inputs=[persistent_session],
734
+ outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
735
  )
736
+ return ui
737
 
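
Note: every auth helper added above (login_user, login_user_recovery, log_out, restore_user_session, change_password, sign_up, reset_password) delegates to logic.supabase_client.auth_handler, which is not part of this diff. The sketch below shows one way that module could be shaped so those call sites work; the class layout, the {'success', 'message', 'data'} result shape, and the use of SUPABASE_URL / SUPABASE_KEY from config.settings are inferred from how the results are consumed, not taken from the committed file.

# logic/supabase_client.py -- hypothetical sketch of the interface assumed by build_ui, not the committed code
from supabase import create_client
from config.settings import SUPABASE_URL, SUPABASE_KEY


class AuthHandler:
    """Every method returns {'success': bool, 'message': str, 'data': dict | None}."""

    def _session_payload(self, client):
        # Package what build_ui reads on success: the per-user client plus what is stored in gr.BrowserState.
        session = client.auth.get_session()
        return {
            "client": client,
            "refresh_token": session.refresh_token,
            "user_email": session.user.email,
        }

    def login(self, email, password):
        try:
            client = create_client(SUPABASE_URL, SUPABASE_KEY)
            client.auth.sign_in_with_password({"email": email, "password": password})
            return {"success": True, "message": "Logged in successfully.", "data": self._session_payload(client)}
        except Exception as e:
            return {"success": False, "message": f"Login failed: {e}", "data": None}

    def restore_session(self, refresh_token):
        try:
            client = create_client(SUPABASE_URL, SUPABASE_KEY)
            client.auth.refresh_session(refresh_token)
            return {"success": True, "message": "Session restored.", "data": self._session_payload(client)}
        except Exception as e:
            return {"success": False, "message": f"Could not restore session: {e}", "data": None}

    def logout(self, client):
        try:
            client.auth.sign_out()
            return {"success": True, "message": "Logged out.", "data": None}
        except Exception as e:
            return {"success": False, "message": f"Logout failed: {e}", "data": None}


auth_handler = AuthHandler()

The remaining methods used above (sign_up, retrieve_session_from_tokens, reset_password_for_email, change_password) would follow the same result shape.
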
 
ui/main_page.py CHANGED
@@ -107,7 +107,36 @@ def build_main_page(concepts_dict, metadata_dict, local_storage):
107
  long_caption_inp = gr.Textbox(lines=6, label="Description", elem_id="long_caption_inp")
108
  num_words_inp = gr.Textbox(lines=1, label="Number of words", elem_id="num_words", interactive=False, value=0)
109
  # num_words_inp = gr.Markdown("Number of words", elem_id="num_words")
 
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  categories_list = sort_with_pyuca(list(concepts_dict["USA"]["English"].keys()))
112
 
113
  def create_category_dropdown(category, index):
@@ -226,5 +255,16 @@ def build_main_page(concepts_dict, metadata_dict, local_storage):
226
  "modal_exclude_confirm": modal_exclude_confirm,
227
  "cancel_exclude_btn": cancel_exclude_btn,
228
  "confirm_exclude_btn": confirm_exclude_btn,
 
 
 
 
 
 
 
 
 
 
 
229
  }
230
  return output_dict
 
107
  long_caption_inp = gr.Textbox(lines=6, label="Description", elem_id="long_caption_inp")
108
  num_words_inp = gr.Textbox(lines=1, label="Number of words", elem_id="num_words", interactive=False, value=0)
109
  # num_words_inp = gr.Markdown("Number of words", elem_id="num_words")
110
+ #########################################################
111
+ with Modal(visible=False, allow_user_close=False) as modal_vlm:
112
 
113
+ question = gr.Markdown("Would you like to see if a VLM can generate a culturally aware description for your uploaded concept?")
114
+ with gr.Row():
115
+ gen_button = gr.Button("Yes", variant="primary", elem_id="generate_answer_btn")
116
+ vlm_no_btn = gr.Button("No")
117
+ vlm_cancel_btn = gr.Button("Cancel")
118
+ vlm_model_dropdown = gr.Dropdown(
119
+ ["SmolVLM-500M", "Qwen2.5-VL-7B", "InternVL3_5-8B", "Gemma3-4B"], value="Gemma3-4B", multiselect=False, label="VLM Model", info="Select the VLM model to use for generating the description."
120
+ )
121
+ vlm_output = gr.Textbox(lines=6, label="Generated description", elem_id="vlm_output", interactive=False)
122
+ vlm_feedback = gr.Radio(["Yes πŸ‘", "No πŸ‘Ž"], label="Do you think the generated description is accurate within the cultural context of your country?", visible=False, elem_id="vlm_feedback", interactive=True)
123
+ vlm_done_btn = gr.Button("Complete Submission", visible=False)
124
+
125
+ with Modal(visible=False, allow_user_close=False) as modal_submit:
126
+
127
+ gr.Markdown("⚠️ You've already generated a caption for this image. An optional description with the VLM can only be generated once. Would you like to proceed and submit your modified data?")
128
+ with gr.Row():
129
+ submit_yes = gr.Button("Yes", variant="primary", elem_id="submit_confirm_yes")
130
+ submit_no = gr.Button("No", variant="stop", elem_id="submit_confirm_no")
131
+
132
+ # with gr.Group():
133
+ # gr.Markdown("### VLM Generation (Optional)")
134
+ # with gr.Accordion("πŸ“˜ Click here if you want to get a generated answer from a small vlm", open=False):
135
+ # gen_button = gr.Button("Generate Answer", variant="primary", elem_id="generate_answer_btn")
136
+ # vlm_output = gr.Textbox(lines=6, label="Generated Answer", elem_id="vlm_output", interactive=False)
137
+ # vlm_feedback = gr.Radio(["Yes πŸ‘", "No πŸ‘Ž"], label="Do you like the generated caption?", visible=False, elem_id="vlm_feedback", interactive=True)
138
+ ##########################################################
139
+
140
  categories_list = sort_with_pyuca(list(concepts_dict["USA"]["English"].keys()))
141
 
142
  def create_category_dropdown(category, index):
 
255
  "modal_exclude_confirm": modal_exclude_confirm,
256
  "cancel_exclude_btn": cancel_exclude_btn,
257
  "confirm_exclude_btn": confirm_exclude_btn,
258
+ "vlm_output": vlm_output,
259
+ "gen_button": gen_button,
260
+ "vlm_feedback": vlm_feedback,
261
+ "modal_vlm": modal_vlm,
262
+ "vlm_no_btn": vlm_no_btn,
263
+ "vlm_done_btn": vlm_done_btn,
264
+ "submit_yes": submit_yes,
265
+ "submit_no": submit_no,
266
+ "modal_submit": modal_submit,
267
+ "vlm_cancel_btn": vlm_cancel_btn,
268
+ "vlm_model_dropdown": vlm_model_dropdown
269
  }
270
  return output_dict
ui/selection_page.py CHANGED
@@ -57,6 +57,24 @@ def build_selection_page(metadata_dict):
57
  username = gr.Textbox(label="Email (optional)", type="email", elem_id="username_text")
58
  password = gr.Textbox(label="Password (optional)", type="password", elem_id="password_text")
59
 
60
- proceed_btn = gr.Button("Proceed")
 
 
 
 
61
 
62
- return selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  username = gr.Textbox(label="Email (optional)", type="email", elem_id="username_text")
58
  password = gr.Textbox(label="Password (optional)", type="password", elem_id="password_text")
59
 
60
+ with gr.Row():
61
+ login_btn = gr.Button("Login", elem_id="login_btn")
62
+ sign_up_btn = gr.Button("Sign up", elem_id="sign_up_btn")
63
+ reset_password_btn = gr.Button("Reset Password", elem_id="reset_password_btn")
64
+ logout_btn = gr.Button("Logout", elem_id="logout_btn",visible=False)
65
 
66
+ login_status = gr.Markdown("")
67
+ with gr.Row():
68
+ proceed_btn = gr.Button("Proceed")
69
+ with gr.Row():
70
+ change_password_field = gr.Textbox(
71
+ label="Change Password", type="password", elem_id="change_password_field", visible=True
72
+ )
73
+ change_password_field_confirm = gr.Textbox(
74
+ label="Confirm New Password", type="password", elem_id="change_password_field_confirm", visible=True
75
+ )
76
+ with gr.Row():
77
+ change_password_btn = gr.Button("Change Password", elem_id="change_password_btn", visible=True)
78
+ change_password_status = gr.Markdown("")
79
+
80
+ return selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown, login_btn, sign_up_btn, reset_password_btn, login_status, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status
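
Note: submit_btn.click(submit_button_clicked, inputs=[vlm_output], outputs=[modal_vlm, modal_submit]) in the build_ui diff routes a submission either to the VLM offer or to the "already generated" warning, but submit_button_clicked itself is not shown in this commit. A minimal sketch of that routing, assuming Modal is the same component class used for the other modals here and that an empty vlm_output means no caption has been generated yet:

# hypothetical routing for submit_btn -- the Modal import path is an assumption, not confirmed by this diff
from gradio_modal import Modal


def submit_button_clicked(vlm_caption):
    if not vlm_caption:
        # No VLM description yet: open the modal offering to generate one before completing the submission.
        return Modal(visible=True), Modal(visible=False)
    # A caption already exists (it can only be generated once): ask the user to confirm re-submission.
    return Modal(visible=False), Modal(visible=True)

The buttons inside those modals (vlm_done_btn, vlm_no_btn, submit_yes) then run wire_submit_chain, which saves through hf_writer, clears the form, and refreshes the user's examples.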