matsant01 committed on
Commit
c61e1ad
·
1 Parent(s): fc4cd8f

Pushing collected preferences to hf dataset

Browse files
Files changed (2) hide show
  1. app.py +160 -67
  2. requirements.txt +3 -1
app.py CHANGED
@@ -3,17 +3,67 @@ import os
3
  import random
4
  import csv
5
  from pathlib import Path
6
- from datetime import datetime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  DATA_DIR = Path("data")
9
- RESULTS_DIR = Path("results")
10
- RESULTS_FILE = RESULTS_DIR / "preferences.csv"
11
  IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # --- Data Loading ---
14
 
15
  def find_image(folder_path: Path, base_name: str) -> Path | None:
16
- """Finds an image file starting with base_name in a folder."""
17
  for ext in IMAGE_EXTENSIONS:
18
  file_path = folder_path / f"{base_name}{ext}"
19
  if file_path.exists():
@@ -21,12 +71,10 @@ def find_image(folder_path: Path, base_name: str) -> Path | None:
21
  return None
22
 
23
  def get_sample_ids() -> list[str]:
24
- """Scans the data directory for valid sample IDs."""
25
  sample_ids = []
26
  if DATA_DIR.is_dir():
27
  for item in DATA_DIR.iterdir():
28
  if item.is_dir():
29
- # Check if required files exist
30
  prompt_file = item / "prompt.txt"
31
  input_bg = find_image(item, "input_bg")
32
  input_fg = find_image(item, "input_fg")
@@ -37,7 +85,6 @@ def get_sample_ids() -> list[str]:
37
  return sample_ids
38
 
39
  def load_sample_data(sample_id: str) -> dict | None:
40
- """Loads data for a specific sample ID."""
41
  sample_path = DATA_DIR / sample_id
42
  if not sample_path.is_dir():
43
  return None
@@ -72,7 +119,6 @@ def load_sample_data(sample_id: str) -> dict | None:
72
  INITIAL_SAMPLE_IDS = get_sample_ids()
73
 
74
  def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]:
75
- """Selects a random sample ID from the available list."""
76
  if not available_ids:
77
  return None, []
78
  chosen_id = random.choice(available_ids)
@@ -80,9 +126,7 @@ def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]:
80
  sample_data = load_sample_data(chosen_id)
81
  return sample_data, remaining_ids
82
 
83
-
84
  def display_new_sample(state: dict, available_ids: list[str]):
85
- """Loads and prepares a new sample for display."""
86
  sample_data, remaining_ids = get_next_sample(available_ids)
87
 
88
  if not sample_data:
@@ -129,16 +173,15 @@ def display_new_sample(state: dict, available_ids: list[str]):
129
  }
130
 
131
  def record_preference(choice: str, state: dict, request: gr.Request):
132
- """Records the user's preference and prepares for the next sample."""
133
- if not request: # Add a check if request is None
134
  print("Error: Request object is None. Cannot get session ID.")
135
- session_id = "unknown_session" # Fallback session ID
136
  else:
137
  try:
138
- session_id = request.client.host # Use IP address as a basic session identifier
139
  except AttributeError:
140
- print("Error: request.client is None or has no 'host' attribute.")
141
- session_id = "unknown_client" # Fallback if client object is weird
142
 
143
  if not state or "current_sample_id" not in state:
144
  print("Warning: State missing, cannot record preference.")
@@ -147,67 +190,97 @@ def record_preference(choice: str, state: dict, request: gr.Request):
147
  choice_button_b: gr.update(interactive=False),
148
  next_button: gr.update(visible=True, interactive=True),
149
  status_display: gr.update(value="Error: Session state lost. Click Next Sample."),
150
- app_state: state # Return unchanged state
151
  }
152
 
153
  chosen_model_name = state["output_a_model_name"] if choice == "A" else state["output_b_model_name"]
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- # Ensure results directory exists
156
- RESULTS_DIR.mkdir(parents=True, exist_ok=True)
157
-
158
- # Append result to CSV
159
- file_exists = RESULTS_FILE.exists()
160
  try:
161
- with open(RESULTS_FILE, 'a', newline='', encoding='utf-8') as f:
162
- writer = csv.writer(f)
163
- if not file_exists:
164
- writer.writerow([
165
- "timestamp", "session_id", "sample_id",
166
- "baseline_displayed_as", "tficon_displayed_as",
167
- "chosen_display", "chosen_model_name"
168
- ]) # Header
169
-
170
- baseline_display = "A" if state["output_a_model_name"] == "baseline" else "B"
171
- tficon_display = "B" if state["output_a_model_name"] == "baseline" else "A"
172
-
173
- writer.writerow([
174
- datetime.now().isoformat(),
175
- session_id,
176
- state["current_sample_id"],
177
- baseline_display,
178
- tficon_display,
179
- choice, # A or B
180
- chosen_model_name # baseline or tf-icon
181
- ])
182
  except Exception as e:
183
- print(f"Error writing results: {e}")
184
  return {
185
  choice_button_a: gr.update(interactive=False),
186
  choice_button_b: gr.update(interactive=False),
187
- next_button: gr.update(visible=True, interactive=True), # Allow user to continue
188
- status_display: gr.update(value=f"Error saving preference: {e}. Click Next Sample."),
189
  app_state: state
190
  }
191
 
192
-
193
- # Update UI: disable choice buttons, show next button
194
  return {
195
  choice_button_a: gr.update(interactive=False),
196
  choice_button_b: gr.update(interactive=False),
197
  next_button: gr.update(visible=True, interactive=True),
198
  status_display: gr.update(value=f"Preference recorded (Chose {choice}). Click Next Sample."),
199
- app_state: state # Return unchanged state
200
  }
201
 
202
- # --- New Handler Functions ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def handle_choice_a(state: dict, request: gr.Request):
204
  return record_preference("A", state, request)
205
 
206
  def handle_choice_b(state: dict, request: gr.Request):
207
  return record_preference("B", state, request)
208
 
209
- # --- Gradio Interface ---
210
-
211
  with gr.Blocks(title="Image Composition User Study") as demo:
212
  gr.Markdown("# Image Composition User Study")
213
  gr.Markdown(
@@ -215,12 +288,9 @@ with gr.Blocks(title="Image Composition User Study") as demo:
215
  "Then, compare the two output images (Output A and Output B) and click the button below the one you prefer."
216
  )
217
 
218
- # State variables
219
- app_state = gr.State({}) # Stores current sample info (id, output mapping)
220
- # Keep track of samples available *for this session*
221
  available_samples_state = gr.State(INITIAL_SAMPLE_IDS)
222
 
223
- # Displays
224
  prompt_display = gr.Textbox(label="Prompt", interactive=False)
225
  status_display = gr.Textbox(label="Status", value="Loading first sample...", interactive=False)
226
 
@@ -241,9 +311,6 @@ with gr.Blocks(title="Image Composition User Study") as demo:
241
 
242
  next_button = gr.Button("Next Sample", visible=False)
243
 
244
- # --- Event Handlers ---
245
-
246
- # Load first sample on page load
247
  demo.load(
248
  fn=display_new_sample,
249
  inputs=[app_state, available_samples_state],
@@ -255,23 +322,20 @@ with gr.Blocks(title="Image Composition User Study") as demo:
255
  ]
256
  )
257
 
258
- # Handle choice A click - Use the new handler function
259
  choice_button_a.click(
260
- fn=handle_choice_a, # Use the dedicated handler
261
- inputs=[app_state], # Input is still just the state component
262
  outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
263
  api_name=False,
264
  )
265
 
266
- # Handle choice B click - Use the new handler function
267
  choice_button_b.click(
268
- fn=handle_choice_b, # Use the dedicated handler
269
- inputs=[app_state], # Input is still just the state component
270
  outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
271
  api_name=False,
272
  )
273
 
274
- # Handle next sample click
275
  next_button.click(
276
  fn=display_new_sample,
277
  inputs=[app_state, available_samples_state],
@@ -282,16 +346,45 @@ with gr.Blocks(title="Image Composition User Study") as demo:
282
  app_state, available_samples_state
283
  ],
284
  api_name=False,
285
- # queue=True
286
  )
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  if __name__ == "__main__":
 
 
289
  if not INITIAL_SAMPLE_IDS:
290
  print("Error: No valid samples found in the 'data' directory.")
291
  print("Please ensure the 'data' directory exists and contains subdirectories")
292
  print("named like 'sample_id', each with 'prompt.txt', 'input_bg.*',")
293
  print("'input_fg.*', 'baseline.*', and 'tf-icon.*' files.")
 
 
 
 
 
 
 
 
 
 
 
294
  else:
 
295
  print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
296
  print("Starting Gradio app...")
297
  demo.launch(server_name="0.0.0.0")
 
3
  import random
4
  import csv
5
  from pathlib import Path
6
+ from datetime import datetime, timedelta
7
+ import tempfile
8
+ from huggingface_hub import HfApi, hf_hub_download, login
9
+ from huggingface_hub.utils import RepositoryNotFoundError, EntryNotFoundError
10
+ from apscheduler.schedulers.background import BackgroundScheduler
11
+ import atexit
12
+ import threading
13
+ import time
14
+ import shutil
15
+
16
# --- Configuration ---
# Dataset repo that receives the results; overridable through the environment.
DATASET_REPO_ID = os.environ.get("DATASET_REPO_ID", "matsant01/user-study-collected-preferences")
# Hub access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Name of the results file, both inside the repo and inside TEMP_DIR.
RESULTS_FILENAME_IN_REPO = "preferences.csv"
# Scratch directory created once at import time; holds the working copy.
TEMP_DIR = tempfile.mkdtemp()
LOCAL_RESULTS_FILE = Path(TEMP_DIR, RESULTS_FILENAME_IN_REPO)
# Upload period in hours (0.1 h = 6 minutes).
UPLOAD_INTERVAL_HOURS = 0.1
23
 
24
  DATA_DIR = Path("data")
 
 
25
  IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
26
 
27
# --- Global State for Upload Logic ---
# Hub client; stays None until initialize_hub_and_results() creates one.
hf_api = None
# Background scheduler used to run the periodic upload job.
scheduler = BackgroundScheduler(daemon=True)
# Serialises access to LOCAL_RESULTS_FILE between request handlers and the uploader.
upload_lock = threading.Lock()
# Set whenever a row is appended locally; cleared once the data has been uploaded.
new_preferences_recorded_since_last_upload = threading.Event()
32
+
33
+ # --- Hugging Face Hub Login & Initialization ---
34
def initialize_hub_and_results():
    """Log in to the Hugging Face Hub and fetch the existing results file.

    With ``HF_TOKEN`` set, this logs in, stores a fresh ``HfApi`` client in
    the module-level ``hf_api`` and tries to download
    ``RESULTS_FILENAME_IN_REPO`` from ``DATASET_REPO_ID`` into ``TEMP_DIR``.

    ``hf_api`` ends up ``None`` when the token is missing or the repository
    cannot be found. A missing results file, or any other error raised after
    the client was created, leaves the client in place so later uploads are
    still attempted.
    """
    global hf_api

    # Guard clause: nothing to do on the Hub without a token.
    if not HF_TOKEN:
        print("Warning: HF_TOKEN secret not found. Results will not be saved to the Hub.")
        hf_api = None
        return

    print("Logging into Hugging Face Hub...")
    try:
        login(token=HF_TOKEN)
        hf_api = HfApi()
        print(f"Attempting initial download of {RESULTS_FILENAME_IN_REPO} from {DATASET_REPO_ID}")
        hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=RESULTS_FILENAME_IN_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
        )
        print(f"Successfully downloaded existing {RESULTS_FILENAME_IN_REPO} to {LOCAL_RESULTS_FILE}")
    except EntryNotFoundError:
        # The repo exists but holds no results yet; the first write creates the file.
        print(f"{RESULTS_FILENAME_IN_REPO} not found in repo. Will create locally.")
    except RepositoryNotFoundError:
        print(f"Error: Dataset repository {DATASET_REPO_ID} not found or token lacks permissions.")
        print("Results saving will be disabled.")
        hf_api = None
    except Exception as exc:
        print(f"Error during initial download/login: {exc}")
        print("Proceeding without initial download. File will be created locally.")
63
+
64
  # --- Data Loading ---
65
 
66
  def find_image(folder_path: Path, base_name: str) -> Path | None:
 
67
  for ext in IMAGE_EXTENSIONS:
68
  file_path = folder_path / f"{base_name}{ext}"
69
  if file_path.exists():
 
71
  return None
72
 
73
  def get_sample_ids() -> list[str]:
 
74
  sample_ids = []
75
  if DATA_DIR.is_dir():
76
  for item in DATA_DIR.iterdir():
77
  if item.is_dir():
 
78
  prompt_file = item / "prompt.txt"
79
  input_bg = find_image(item, "input_bg")
80
  input_fg = find_image(item, "input_fg")
 
85
  return sample_ids
86
 
87
  def load_sample_data(sample_id: str) -> dict | None:
 
88
  sample_path = DATA_DIR / sample_id
89
  if not sample_path.is_dir():
90
  return None
 
119
  INITIAL_SAMPLE_IDS = get_sample_ids()
120
 
121
  def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]:
 
122
  if not available_ids:
123
  return None, []
124
  chosen_id = random.choice(available_ids)
 
126
  sample_data = load_sample_data(chosen_id)
127
  return sample_data, remaining_ids
128
 
 
129
  def display_new_sample(state: dict, available_ids: list[str]):
 
130
  sample_data, remaining_ids = get_next_sample(available_ids)
131
 
132
  if not sample_data:
 
173
  }
174
 
175
  def record_preference(choice: str, state: dict, request: gr.Request):
176
+ if not request:
 
177
  print("Error: Request object is None. Cannot get session ID.")
178
+ session_id = "unknown_session"
179
  else:
180
  try:
181
+ session_id = request.client.host
182
  except AttributeError:
183
+ print("Error: request.client is None or has no 'host' attribute.")
184
+ session_id = "unknown_client"
185
 
186
  if not state or "current_sample_id" not in state:
187
  print("Warning: State missing, cannot record preference.")
 
190
  choice_button_b: gr.update(interactive=False),
191
  next_button: gr.update(visible=True, interactive=True),
192
  status_display: gr.update(value="Error: Session state lost. Click Next Sample."),
193
+ app_state: state
194
  }
195
 
196
  chosen_model_name = state["output_a_model_name"] if choice == "A" else state["output_b_model_name"]
197
+ baseline_display = "A" if state["output_a_model_name"] == "baseline" else "B"
198
+ tficon_display = "B" if state["output_a_model_name"] == "baseline" else "A"
199
+
200
+ new_row = {
201
+ "timestamp": datetime.now().isoformat(),
202
+ "session_id": session_id,
203
+ "sample_id": state["current_sample_id"],
204
+ "baseline_displayed_as": baseline_display,
205
+ "tficon_displayed_as": tficon_display,
206
+ "chosen_display": choice,
207
+ "chosen_model_name": chosen_model_name
208
+ }
209
+ header = list(new_row.keys())
210
 
 
 
 
 
 
211
  try:
212
+ with upload_lock:
213
+ file_exists = LOCAL_RESULTS_FILE.exists()
214
+ mode = 'a' if file_exists else 'w'
215
+ with open(LOCAL_RESULTS_FILE, mode, newline='', encoding='utf-8') as f:
216
+ writer = csv.DictWriter(f, fieldnames=header)
217
+ if not file_exists or os.path.getsize(LOCAL_RESULTS_FILE) == 0:
218
+ writer.writeheader()
219
+ print(f"Created or wrote header to {LOCAL_RESULTS_FILE}")
220
+ writer.writerow(new_row)
221
+ print(f"Appended preference for {state['current_sample_id']} to local file.")
222
+ new_preferences_recorded_since_last_upload.set()
223
+
 
 
 
 
 
 
 
 
 
224
  except Exception as e:
225
+ print(f"Error writing local results file {LOCAL_RESULTS_FILE}: {e}")
226
  return {
227
  choice_button_a: gr.update(interactive=False),
228
  choice_button_b: gr.update(interactive=False),
229
+ next_button: gr.update(visible=True, interactive=True),
230
+ status_display: gr.update(value=f"Error saving preference locally: {e}. Click Next."),
231
  app_state: state
232
  }
233
 
 
 
234
  return {
235
  choice_button_a: gr.update(interactive=False),
236
  choice_button_b: gr.update(interactive=False),
237
  next_button: gr.update(visible=True, interactive=True),
238
  status_display: gr.update(value=f"Preference recorded (Chose {choice}). Click Next Sample."),
239
+ app_state: state
240
  }
241
 
242
# Held for the whole of one upload attempt so that two attempts (for example
# the scheduled job and the shutdown hook) cannot push snapshots out of order.
# Deliberately separate from `upload_lock`, which only guards local file access.
_hub_upload_serial_lock = threading.Lock()


def upload_preferences_to_hub():
    """Upload the local preferences CSV to the dataset repo if it has new rows.

    Does nothing when no Hub client is available, when no preference has been
    recorded since the last successful upload, or when the local file is
    missing or empty.

    The file content is snapshotted while holding ``upload_lock``, and the
    network transfer then runs without that lock. Code that appends rows
    under ``upload_lock`` is therefore never stalled by a Hub round-trip.
    All errors are printed, never raised.
    """
    print("Periodic upload check triggered.")
    if not hf_api:
        print("Upload check skipped: Hugging Face API not available.")
        return

    # Cheap unlocked pre-check; repeated under the lock below.
    if not new_preferences_recorded_since_last_upload.is_set():
        print("Upload check skipped: No new preferences recorded since last upload.")
        return

    with _hub_upload_serial_lock:
        with upload_lock:
            if not new_preferences_recorded_since_last_upload.is_set():
                print("Upload check skipped (race condition avoided): No new preferences.")
                return

            if not LOCAL_RESULTS_FILE.exists() or os.path.getsize(LOCAL_RESULTS_FILE) == 0:
                print("Upload check skipped: Local results file is missing or empty.")
                new_preferences_recorded_since_last_upload.clear()
                return

            try:
                snapshot = LOCAL_RESULTS_FILE.read_bytes()
            except OSError as e:
                # Leave the flag set so the next run retries.
                print(f"Error uploading results file: {e}")
                return

            # Clear while still holding the lock: a row appended after this
            # point sets the flag again and is picked up by the next run.
            new_preferences_recorded_since_last_upload.clear()

        try:
            print(f"Attempting to upload {LOCAL_RESULTS_FILE} to {DATASET_REPO_ID}/{RESULTS_FILENAME_IN_REPO}")
            start_time = time.time()
            hf_api.upload_file(
                path_or_fileobj=snapshot,
                path_in_repo=RESULTS_FILENAME_IN_REPO,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message=f"Periodic upload of preferences - {datetime.now().isoformat()}",
            )
            end_time = time.time()
            print(f"Successfully uploaded preferences. Took {end_time - start_time:.2f} seconds.")
        except Exception as e:
            print(f"Error uploading results file: {e}")
            # The snapshot never reached the Hub; mark the data as pending again.
            new_preferences_recorded_since_last_upload.set()
277
+
278
  def handle_choice_a(state: dict, request: gr.Request):
279
  return record_preference("A", state, request)
280
 
281
  def handle_choice_b(state: dict, request: gr.Request):
282
  return record_preference("B", state, request)
283
 
 
 
284
  with gr.Blocks(title="Image Composition User Study") as demo:
285
  gr.Markdown("# Image Composition User Study")
286
  gr.Markdown(
 
288
  "Then, compare the two output images (Output A and Output B) and click the button below the one you prefer."
289
  )
290
 
291
+ app_state = gr.State({})
 
 
292
  available_samples_state = gr.State(INITIAL_SAMPLE_IDS)
293
 
 
294
  prompt_display = gr.Textbox(label="Prompt", interactive=False)
295
  status_display = gr.Textbox(label="Status", value="Loading first sample...", interactive=False)
296
 
 
311
 
312
  next_button = gr.Button("Next Sample", visible=False)
313
 
 
 
 
314
  demo.load(
315
  fn=display_new_sample,
316
  inputs=[app_state, available_samples_state],
 
322
  ]
323
  )
324
 
 
325
  choice_button_a.click(
326
+ fn=handle_choice_a,
327
+ inputs=[app_state],
328
  outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
329
  api_name=False,
330
  )
331
 
 
332
  choice_button_b.click(
333
+ fn=handle_choice_b,
334
+ inputs=[app_state],
335
  outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
336
  api_name=False,
337
  )
338
 
 
339
  next_button.click(
340
  fn=display_new_sample,
341
  inputs=[app_state, available_samples_state],
 
346
  app_state, available_samples_state
347
  ],
348
  api_name=False,
 
349
  )
350
 
351
def cleanup_temp_dir():
    """Delete the temporary results directory if it still exists."""
    temp_root = Path(TEMP_DIR)
    if not temp_root.exists():
        return
    print(f"Cleaning up temporary directory: {TEMP_DIR}")
    # Best-effort removal: errors are ignored rather than raised.
    shutil.rmtree(TEMP_DIR, ignore_errors=True)
355
+
356
@atexit.register
def shutdown_hook():
    """Exit handler: final upload, stop the scheduler, remove the temp dir.

    Registered with ``atexit`` through the decorator, which returns the
    function unchanged, so ``shutdown_hook`` stays callable by name.
    """
    print("Application shutting down. Performing final upload check...")
    # Push pending rows first; the temp dir holding them is deleted below.
    upload_preferences_to_hub()
    if scheduler.running:
        print("Shutting down scheduler...")
        scheduler.shutdown(wait=False)
    cleanup_temp_dir()
    print("Shutdown complete.")
366
+
367
  if __name__ == "__main__":
368
+ initialize_hub_and_results()
369
+
370
  if not INITIAL_SAMPLE_IDS:
371
  print("Error: No valid samples found in the 'data' directory.")
372
  print("Please ensure the 'data' directory exists and contains subdirectories")
373
  print("named like 'sample_id', each with 'prompt.txt', 'input_bg.*',")
374
  print("'input_fg.*', 'baseline.*', and 'tf-icon.*' files.")
375
+ elif not DATASET_REPO_ID:
376
+ print("Error: DATASET_REPO_ID environment variable is not set or is set to the default placeholder.")
377
+ print("Please set the DATASET_REPO_ID environment variable or update the script.")
378
+ elif hf_api:
379
+ print(f"Starting periodic upload scheduler (every {UPLOAD_INTERVAL_HOURS} hours)...")
380
+ scheduler.add_job(upload_preferences_to_hub, 'interval', hours=UPLOAD_INTERVAL_HOURS)
381
+ scheduler.start()
382
+ print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
383
+ print(f"Configured to save results periodically to Hugging Face Dataset: {DATASET_REPO_ID}")
384
+ print("Starting Gradio app...")
385
+ demo.launch(server_name="0.0.0.0")
386
  else:
387
+ print("Warning: Running without Hugging Face Hub integration (HF_TOKEN or DATASET_REPO_ID missing/invalid).")
388
  print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
389
  print("Starting Gradio app...")
390
  demo.launch(server_name="0.0.0.0")
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- gradio
 
 
 
1
+ gradio
2
+ huggingface_hub
3
+ apscheduler # Added for periodic tasks