Spaces:
Running
Running
import gradio as gr | |
import os | |
import random | |
import csv | |
from pathlib import Path | |
from datetime import datetime | |
# Input samples live in one sub-directory per sample under DATA_DIR.
DATA_DIR = Path("data")
# Collected preferences are appended to a CSV file under RESULTS_DIR.
RESULTS_DIR = Path("results")
RESULTS_FILE = RESULTS_DIR.joinpath("preferences.csv")
# Image suffixes tried, in this order, when looking up a sample's image files.
IMAGE_EXTENSIONS = ".png .jpg .jpeg .webp".split()
# --- Data Loading --- | |
def find_image(folder_path: Path, base_name: str) -> Path | None:
    """Return the first existing ``<base_name><ext>`` path inside *folder_path*.

    Extensions are tried in the order given by ``IMAGE_EXTENSIONS`` and the
    search stops at the first candidate that exists. Returns ``None`` when no
    candidate exists.
    """
    candidates = (folder_path / (base_name + suffix) for suffix in IMAGE_EXTENSIONS)
    return next((candidate for candidate in candidates if candidate.exists()), None)
def get_sample_ids() -> list[str]:
    """Return the names of ``DATA_DIR`` sub-directories that form complete samples.

    A sub-directory qualifies when it contains ``prompt.txt`` and, for each of
    ``input_bg``, ``input_fg``, ``baseline`` and ``tf-icon``, an image with one of
    the ``IMAGE_EXTENSIONS`` suffixes. Returns an empty list when ``DATA_DIR`` is
    not a directory. The order is whatever ``Path.iterdir()`` yields.
    """
    valid_ids: list[str] = []
    if not DATA_DIR.is_dir():
        return valid_ids

    for entry in DATA_DIR.iterdir():
        if not entry.is_dir():
            continue
        # Look up every required image first, then check the prompt file.
        required_images = [
            find_image(entry, stem)
            for stem in ("input_bg", "input_fg", "baseline", "tf-icon")
        ]
        if (entry / "prompt.txt").exists() and all(required_images):
            valid_ids.append(entry.name)
    return valid_ids
def load_sample_data(sample_id: str) -> dict | None:
    """Load the prompt text and image paths for one sample directory.

    Args:
        sample_id: Name of a sub-directory of ``DATA_DIR``.

    Returns:
        A dict with keys ``id``, ``prompt``, ``input_bg``, ``input_fg``,
        ``output_baseline`` and ``output_tficon`` (image paths as strings), or
        ``None`` when the directory is missing, lacks a required file, or its
        ``prompt.txt`` cannot be read or decoded.
    """
    sample_path = DATA_DIR / sample_id
    if not sample_path.is_dir():
        return None

    prompt_file = sample_path / "prompt.txt"
    input_bg_path = find_image(sample_path, "input_bg")
    input_fg_path = find_image(sample_path, "input_fg")
    output_baseline_path = find_image(sample_path, "baseline")
    output_tficon_path = find_image(sample_path, "tf-icon")
    if not all([prompt_file.exists(), input_bg_path, input_fg_path, output_baseline_path, output_tficon_path]):
        print(f"Warning: Missing files in sample {sample_id}")
        return None

    try:
        # Decode explicitly as UTF-8 (the results CSV is written as UTF-8 too)
        # rather than relying on the host's locale encoding.
        prompt = prompt_file.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading prompt for {sample_id}: {e}")
        return None

    return {
        "id": sample_id,
        "prompt": prompt,
        "input_bg": str(input_bg_path),
        "input_fg": str(input_fg_path),
        "output_baseline": str(output_baseline_path),
        "output_tficon": str(output_tficon_path),
    }
# --- State and UI Logic --- | |
# Scanned once, when the module is imported. Every session's
# `available_samples_state` starts from this list, so samples added to DATA_DIR
# afterwards are not picked up until the app is restarted.
INITIAL_SAMPLE_IDS: list[str] = get_sample_ids()
def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]: | |
"""Selects a random sample ID from the available list.""" | |
if not available_ids: | |
return None, [] | |
chosen_id = random.choice(available_ids) | |
remaining_ids = [id for id in available_ids if id != chosen_id] | |
sample_data = load_sample_data(chosen_id) | |
return sample_data, remaining_ids | |
def display_new_sample(state: dict, available_ids: list[str]):
    """Draw the next sample and build the Gradio updates that display it.

    Args:
        state: The session's current ``app_state``. It is passed back unchanged
            only when there is nothing left to show.
        available_ids: Sample IDs this session has not been shown yet.

    Returns:
        A dict mapping the module-level UI components to ``gr.update`` objects,
        plus new values for ``app_state`` and ``available_samples_state``.
    """
    sample_data, remaining_ids = get_next_sample(available_ids)

    if not sample_data:
        # Study finished: hide every image and button and show a completion message.
        hidden_images = {
            component: gr.update(value=None, visible=False)
            for component in (input_bg_display, input_fg_display, output_a_display, output_b_display)
        }
        hidden_buttons = {
            component: gr.update(visible=False)
            for component in (choice_button_a, choice_button_b, next_button)
        }
        return {
            prompt_display: gr.update(value="No more samples available. Thank you!"),
            **hidden_images,
            **hidden_buttons,
            status_display: gr.update(value="Completed!"),
            app_state: state,
            available_samples_state: remaining_ids,
        }

    # Shuffle so that each model's position (A or B) is random per sample.
    candidates = [
        {"model_name": "baseline", "path": sample_data["output_baseline"]},
        {"model_name": "tf-icon", "path": sample_data["output_tficon"]},
    ]
    random.shuffle(candidates)
    left, right = candidates

    # Remember which model sits behind each button so a click can be decoded later.
    new_state = {
        "current_sample_id": sample_data["id"],
        "output_a_model_name": left["model_name"],
        "output_b_model_name": right["model_name"],
    }

    return {
        prompt_display: gr.update(value=f"Prompt: {sample_data['prompt']}"),
        input_bg_display: gr.update(value=sample_data["input_bg"], visible=True),
        input_fg_display: gr.update(value=sample_data["input_fg"], visible=True),
        output_a_display: gr.update(value=left["path"], visible=True),
        output_b_display: gr.update(value=right["path"], visible=True),
        choice_button_a: gr.update(visible=True, interactive=True),
        choice_button_b: gr.update(visible=True, interactive=True),
        next_button: gr.update(visible=False),
        status_display: gr.update(value="Please choose the image you prefer."),
        app_state: new_state,
        available_samples_state: remaining_ids,
    }
def _session_id_from_request(request) -> str:
    """Return the client IP from *request* as a coarse session id, or a fallback label."""
    if request is None:
        print("Error: Request object is None. Cannot get session ID.")
        return "unknown_session"
    try:
        # NOTE(review): behind a reverse proxy this may be the proxy's address,
        # not the participant's; `request.session_hash` may be a better key. Confirm.
        return request.client.host
    except AttributeError:
        print("Error: request.client is None or has no 'host' attribute.")
        return "unknown_client"


def _after_choice_updates(state, status_message: str) -> dict:
    """Build the UI updates shown after a choice attempt: lock choices, reveal Next."""
    return {
        choice_button_a: gr.update(interactive=False),
        choice_button_b: gr.update(interactive=False),
        next_button: gr.update(visible=True, interactive=True),
        status_display: gr.update(value=status_message),
        app_state: state,  # state is returned unchanged
    }


def record_preference(choice: str, state: dict, request: gr.Request):
    """Append the participant's A/B choice to ``RESULTS_FILE`` and update the UI.

    Args:
        choice: ``"A"`` or ``"B"``, the displayed output that was picked.
        state: Session state set by ``display_new_sample``. It must contain
            ``current_sample_id``, ``output_a_model_name`` and
            ``output_b_model_name``.
        request: The Gradio request, used only to derive a session identifier.
            May be ``None``.

    Returns:
        A dict of ``gr.update`` objects for the choice buttons, Next button and
        status box, plus ``app_state`` (unchanged). On missing state or a write
        failure, the status box reports the error and the user can move on.
    """
    session_id = _session_id_from_request(request)

    required_keys = ("current_sample_id", "output_a_model_name", "output_b_model_name")
    if not state or any(key not in state for key in required_keys):
        print("Warning: State missing, cannot record preference.")
        return _after_choice_updates(state, "Error: Session state lost. Click Next Sample.")

    output_a_model = state["output_a_model_name"]
    chosen_model_name = output_a_model if choice == "A" else state["output_b_model_name"]
    baseline_display = "A" if output_a_model == "baseline" else "B"
    tficon_display = "B" if baseline_display == "A" else "A"

    header = [
        "timestamp", "session_id", "sample_id",
        "baseline_displayed_as", "tficon_displayed_as",
        "chosen_display", "chosen_model_name",
    ]
    row = [
        datetime.now().isoformat(),
        session_id,
        state["current_sample_id"],
        baseline_display,
        tficon_display,
        choice,             # "A" or "B"
        chosen_model_name,  # "baseline" or "tf-icon"
    ]

    try:
        RESULTS_DIR.mkdir(parents=True, exist_ok=True)
        with open(RESULTS_FILE, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            # Decide on the header *after* opening. In append mode the position
            # starts at end-of-file, so tell() == 0 means the file is empty: either
            # newly created or left empty. An exists() check made before open()
            # misses both the empty-file and the concurrent-creation cases.
            if f.tell() == 0:
                writer.writerow(header)
            writer.writerow(row)
    except (OSError, csv.Error) as e:
        print(f"Error writing results: {e}")
        return _after_choice_updates(state, f"Error saving preference: {e}. Click Next Sample.")

    return _after_choice_updates(state, f"Preference recorded (Chose {choice}). Click Next Sample.")
# --- New Handler Functions --- | |
def handle_choice_a(state: dict, request: gr.Request):
    """Click handler for the "Choose Output A" button; records choice ``"A"``."""
    return record_preference(choice="A", state=state, request=request)
def handle_choice_b(state: dict, request: gr.Request):
    """Click handler for the "Choose Output B" button; records choice ``"B"``."""
    return record_preference(choice="B", state=state, request=request)
# --- Gradio Interface --- | |
with gr.Blocks(title="Image Composition User Study") as demo:
    gr.Markdown("# Image Composition User Study")
    gr.Markdown(
        "Please look at the input images and the prompt below. "
        "Then, compare the two output images (Output A and Output B) and click the button below the one you prefer."
    )

    # Per-session state: the current sample id and which model is shown as A / B.
    app_state = gr.State({})
    # Sample IDs not yet shown in this session; starts from the list scanned at import.
    available_samples_state = gr.State(INITIAL_SAMPLE_IDS)

    prompt_display = gr.Textbox(label="Prompt", interactive=False)
    status_display = gr.Textbox(label="Status", value="Loading first sample...", interactive=False)
    with gr.Row():
        input_bg_display = gr.Image(label="Input Background", type="filepath", height=300, width=300, interactive=False)
        input_fg_display = gr.Image(label="Input Foreground", type="filepath", height=300, width=300, interactive=False)

    gr.Markdown("---")
    gr.Markdown("## Choose your preferred output:")
    with gr.Row():
        with gr.Column():
            output_a_display = gr.Image(label="Output A", type="filepath", height=400, width=400, interactive=False)
            choice_button_a = gr.Button("Choose Output A", variant="primary")
        with gr.Column():
            output_b_display = gr.Image(label="Output B", type="filepath", height=400, width=400, interactive=False)
            choice_button_b = gr.Button("Choose Output B", variant="primary")
    next_button = gr.Button("Next Sample", visible=False)

    # Every component that display_new_sample returns an update for.
    sample_outputs = [
        prompt_display, input_bg_display, input_fg_display,
        output_a_display, output_b_display,
        choice_button_a, choice_button_b, next_button, status_display,
        app_state, available_samples_state,
    ]
    # Every component that record_preference returns an update for.
    choice_outputs = [choice_button_a, choice_button_b, next_button, status_display, app_state]

    # Show the first sample when the page loads. Like the other events, it is
    # not exposed as a public API endpoint.
    demo.load(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=sample_outputs,
        api_name=False,
    )
    # The request parameter of each handler is injected by Gradio from its
    # `gr.Request` annotation, so only the state is listed as an input.
    choice_button_a.click(
        fn=handle_choice_a,
        inputs=[app_state],
        outputs=choice_outputs,
        api_name=False,
    )
    choice_button_b.click(
        fn=handle_choice_b,
        inputs=[app_state],
        outputs=choice_outputs,
        api_name=False,
    )
    next_button.click(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=sample_outputs,
        api_name=False,
    )
if __name__ == "__main__":
    if not INITIAL_SAMPLE_IDS:
        import sys

        # Report to stderr and exit non-zero so a supervisor sees that startup failed.
        print("Error: No valid samples found in the 'data' directory.", file=sys.stderr)
        print("Please ensure the 'data' directory exists and contains subdirectories", file=sys.stderr)
        print("named like 'sample_id', each with 'prompt.txt', 'input_bg.*',", file=sys.stderr)
        print("'input_fg.*', 'baseline.*', and 'tf-icon.*' files.", file=sys.stderr)
        sys.exit(1)

    print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
    print("Starting Gradio app...")
    # Bind on all interfaces so the app is reachable from outside the container.
    demo.launch(server_name="0.0.0.0")