matsant01 commited on
Commit
540a985
·
verified ·
1 Parent(s): 66b5780

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +297 -0
app.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import random
4
+ import csv
5
+ from pathlib import Path
6
+ from datetime import datetime
7
+
8
+ DATA_DIR = Path("data")
9
+ RESULTS_DIR = Path("results")
10
+ RESULTS_FILE = RESULTS_DIR / "preferences.csv"
11
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
12
+
13
+ # --- Data Loading ---
14
+
15
+ def find_image(folder_path: Path, base_name: str) -> Path | None:
16
+ """Finds an image file starting with base_name in a folder."""
17
+ for ext in IMAGE_EXTENSIONS:
18
+ file_path = folder_path / f"{base_name}{ext}"
19
+ if file_path.exists():
20
+ return file_path
21
+ return None
22
+
23
+ def get_sample_ids() -> list[str]:
24
+ """Scans the data directory for valid sample IDs."""
25
+ sample_ids = []
26
+ if DATA_DIR.is_dir():
27
+ for item in DATA_DIR.iterdir():
28
+ if item.is_dir():
29
+ # Check if required files exist
30
+ prompt_file = item / "prompt.txt"
31
+ input_bg = find_image(item, "input_bg")
32
+ input_fg = find_image(item, "input_fg")
33
+ output_baseline = find_image(item, "baseline")
34
+ output_tficon = find_image(item, "tf-icon")
35
+ if prompt_file.exists() and input_bg and input_fg and output_baseline and output_tficon:
36
+ sample_ids.append(item.name)
37
+ return sample_ids
38
+
39
+ def load_sample_data(sample_id: str) -> dict | None:
40
+ """Loads data for a specific sample ID."""
41
+ sample_path = DATA_DIR / sample_id
42
+ if not sample_path.is_dir():
43
+ return None
44
+
45
+ prompt_file = sample_path / "prompt.txt"
46
+ input_bg_path = find_image(sample_path, "input_bg")
47
+ input_fg_path = find_image(sample_path, "input_fg")
48
+ output_baseline_path = find_image(sample_path, "baseline")
49
+ output_tficon_path = find_image(sample_path, "tf-icon")
50
+
51
+ if not all([prompt_file.exists(), input_bg_path, input_fg_path, output_baseline_path, output_tficon_path]):
52
+ print(f"Warning: Missing files in sample {sample_id}")
53
+ return None
54
+
55
+ try:
56
+ prompt = prompt_file.read_text().strip()
57
+ except Exception as e:
58
+ print(f"Error reading prompt for {sample_id}: {e}")
59
+ return None
60
+
61
+ return {
62
+ "id": sample_id,
63
+ "prompt": prompt,
64
+ "input_bg": str(input_bg_path),
65
+ "input_fg": str(input_fg_path),
66
+ "output_baseline": str(output_baseline_path),
67
+ "output_tficon": str(output_tficon_path),
68
+ }
69
+
70
+ # --- State and UI Logic ---
71
+
72
+ INITIAL_SAMPLE_IDS = get_sample_ids()
73
+
74
+ def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]:
75
+ """Selects a random sample ID from the available list."""
76
+ if not available_ids:
77
+ return None, []
78
+ chosen_id = random.choice(available_ids)
79
+ remaining_ids = [id for id in available_ids if id != chosen_id]
80
+ sample_data = load_sample_data(chosen_id)
81
+ return sample_data, remaining_ids
82
+
83
+
84
+ def display_new_sample(state: dict, available_ids: list[str]):
85
+ """Loads and prepares a new sample for display."""
86
+ sample_data, remaining_ids = get_next_sample(available_ids)
87
+
88
+ if not sample_data:
89
+ return {
90
+ prompt_display: gr.update(value="No more samples available. Thank you!"),
91
+ input_bg_display: gr.update(value=None, visible=False),
92
+ input_fg_display: gr.update(value=None, visible=False),
93
+ output_a_display: gr.update(value=None, visible=False),
94
+ output_b_display: gr.update(value=None, visible=False),
95
+ choice_button_a: gr.update(visible=False),
96
+ choice_button_b: gr.update(visible=False),
97
+ next_button: gr.update(visible=False),
98
+ status_display: gr.update(value="Completed!"),
99
+ app_state: state,
100
+ available_samples_state: remaining_ids
101
+ }
102
+
103
+ outputs = [
104
+ {"model_name": "baseline", "path": sample_data["output_baseline"]},
105
+ {"model_name": "tf-icon", "path": sample_data["output_tficon"]},
106
+ ]
107
+ random.shuffle(outputs)
108
+ output_a = outputs[0]
109
+ output_b = outputs[1]
110
+
111
+ state = {
112
+ "current_sample_id": sample_data["id"],
113
+ "output_a_model_name": output_a["model_name"],
114
+ "output_b_model_name": output_b["model_name"],
115
+ }
116
+
117
+ return {
118
+ prompt_display: gr.update(value=f"Prompt: {sample_data['prompt']}"),
119
+ input_bg_display: gr.update(value=sample_data["input_bg"], visible=True),
120
+ input_fg_display: gr.update(value=sample_data["input_fg"], visible=True),
121
+ output_a_display: gr.update(value=output_a["path"], visible=True),
122
+ output_b_display: gr.update(value=output_b["path"], visible=True),
123
+ choice_button_a: gr.update(visible=True, interactive=True),
124
+ choice_button_b: gr.update(visible=True, interactive=True),
125
+ next_button: gr.update(visible=False),
126
+ status_display: gr.update(value="Please choose the image you prefer."),
127
+ app_state: state,
128
+ available_samples_state: remaining_ids
129
+ }
130
+
131
+ def record_preference(choice: str, state: dict, request: gr.Request):
132
+ """Records the user's preference and prepares for the next sample."""
133
+ if not request: # Add a check if request is None
134
+ print("Error: Request object is None. Cannot get session ID.")
135
+ session_id = "unknown_session" # Fallback session ID
136
+ else:
137
+ try:
138
+ session_id = request.client.host # Use IP address as a basic session identifier
139
+ except AttributeError:
140
+ print("Error: request.client is None or has no 'host' attribute.")
141
+ session_id = "unknown_client" # Fallback if client object is weird
142
+
143
+ if not state or "current_sample_id" not in state:
144
+ print("Warning: State missing, cannot record preference.")
145
+ return {
146
+ choice_button_a: gr.update(interactive=False),
147
+ choice_button_b: gr.update(interactive=False),
148
+ next_button: gr.update(visible=True, interactive=True),
149
+ status_display: gr.update(value="Error: Session state lost. Click Next Sample."),
150
+ app_state: state # Return unchanged state
151
+ }
152
+
153
+ chosen_model_name = state["output_a_model_name"] if choice == "A" else state["output_b_model_name"]
154
+
155
+ # Ensure results directory exists
156
+ RESULTS_DIR.mkdir(parents=True, exist_ok=True)
157
+
158
+ # Append result to CSV
159
+ file_exists = RESULTS_FILE.exists()
160
+ try:
161
+ with open(RESULTS_FILE, 'a', newline='', encoding='utf-8') as f:
162
+ writer = csv.writer(f)
163
+ if not file_exists:
164
+ writer.writerow([
165
+ "timestamp", "session_id", "sample_id",
166
+ "baseline_displayed_as", "tficon_displayed_as",
167
+ "chosen_display", "chosen_model_name"
168
+ ]) # Header
169
+
170
+ baseline_display = "A" if state["output_a_model_name"] == "baseline" else "B"
171
+ tficon_display = "B" if state["output_a_model_name"] == "baseline" else "A"
172
+
173
+ writer.writerow([
174
+ datetime.now().isoformat(),
175
+ session_id,
176
+ state["current_sample_id"],
177
+ baseline_display,
178
+ tficon_display,
179
+ choice, # A or B
180
+ chosen_model_name # baseline or tf-icon
181
+ ])
182
+ except Exception as e:
183
+ print(f"Error writing results: {e}")
184
+ return {
185
+ choice_button_a: gr.update(interactive=False),
186
+ choice_button_b: gr.update(interactive=False),
187
+ next_button: gr.update(visible=True, interactive=True), # Allow user to continue
188
+ status_display: gr.update(value=f"Error saving preference: {e}. Click Next Sample."),
189
+ app_state: state
190
+ }
191
+
192
+
193
+ # Update UI: disable choice buttons, show next button
194
+ return {
195
+ choice_button_a: gr.update(interactive=False),
196
+ choice_button_b: gr.update(interactive=False),
197
+ next_button: gr.update(visible=True, interactive=True),
198
+ status_display: gr.update(value=f"Preference recorded (Chose {choice}). Click Next Sample."),
199
+ app_state: state # Return unchanged state
200
+ }
201
+
202
+ # --- New Handler Functions ---
203
+ def handle_choice_a(state: dict, request: gr.Request):
204
+ return record_preference("A", state, request)
205
+
206
+ def handle_choice_b(state: dict, request: gr.Request):
207
+ return record_preference("B", state, request)
208
+
209
+ # --- Gradio Interface ---
210
+
211
+ with gr.Blocks(title="Image Composition User Study") as demo:
212
+ gr.Markdown("# Image Composition User Study")
213
+ gr.Markdown(
214
+ "Please look at the input images and the prompt below. "
215
+ "Then, compare the two output images (Output A and Output B) and click the button below the one you prefer."
216
+ )
217
+
218
+ # State variables
219
+ app_state = gr.State({}) # Stores current sample info (id, output mapping)
220
+ # Keep track of samples available *for this session*
221
+ available_samples_state = gr.State(INITIAL_SAMPLE_IDS)
222
+
223
+ # Displays
224
+ prompt_display = gr.Textbox(label="Prompt", interactive=False)
225
+ status_display = gr.Textbox(label="Status", value="Loading first sample...", interactive=False)
226
+
227
+ with gr.Row():
228
+ input_bg_display = gr.Image(label="Input Background", type="filepath", height=300, width=300, interactive=False)
229
+ input_fg_display = gr.Image(label="Input Foreground", type="filepath", height=300, width=300, interactive=False)
230
+
231
+ gr.Markdown("---")
232
+ gr.Markdown("## Choose your preferred output:")
233
+
234
+ with gr.Row():
235
+ with gr.Column():
236
+ output_a_display = gr.Image(label="Output A", type="filepath", height=400, width=400, interactive=False)
237
+ choice_button_a = gr.Button("Choose Output A", variant="primary")
238
+ with gr.Column():
239
+ output_b_display = gr.Image(label="Output B", type="filepath", height=400, width=400, interactive=False)
240
+ choice_button_b = gr.Button("Choose Output B", variant="primary")
241
+
242
+ next_button = gr.Button("Next Sample", visible=False)
243
+
244
+ # --- Event Handlers ---
245
+
246
+ # Load first sample on page load
247
+ demo.load(
248
+ fn=display_new_sample,
249
+ inputs=[app_state, available_samples_state],
250
+ outputs=[
251
+ prompt_display, input_bg_display, input_fg_display,
252
+ output_a_display, output_b_display,
253
+ choice_button_a, choice_button_b, next_button, status_display,
254
+ app_state, available_samples_state
255
+ ]
256
+ )
257
+
258
+ # Handle choice A click - Use the new handler function
259
+ choice_button_a.click(
260
+ fn=handle_choice_a, # Use the dedicated handler
261
+ inputs=[app_state], # Input is still just the state component
262
+ outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
263
+ api_name=False,
264
+ )
265
+
266
+ # Handle choice B click - Use the new handler function
267
+ choice_button_b.click(
268
+ fn=handle_choice_b, # Use the dedicated handler
269
+ inputs=[app_state], # Input is still just the state component
270
+ outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
271
+ api_name=False,
272
+ )
273
+
274
+ # Handle next sample click
275
+ next_button.click(
276
+ fn=display_new_sample,
277
+ inputs=[app_state, available_samples_state],
278
+ outputs=[
279
+ prompt_display, input_bg_display, input_fg_display,
280
+ output_a_display, output_b_display,
281
+ choice_button_a, choice_button_b, next_button, status_display,
282
+ app_state, available_samples_state
283
+ ],
284
+ api_name=False,
285
+ # queue=True
286
+ )
287
+
288
+ if __name__ == "__main__":
289
+ if not INITIAL_SAMPLE_IDS:
290
+ print("Error: No valid samples found in the 'data' directory.")
291
+ print("Please ensure the 'data' directory exists and contains subdirectories")
292
+ print("named like 'sample_id', each with 'prompt.txt', 'input_bg.*',")
293
+ print("'input_fg.*', 'baseline.*', and 'tf-icon.*' files.")
294
+ else:
295
+ print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
296
+ print("Starting Gradio app...")
297
+ demo.launch(server_name="0.0.0.0")