# app.py — uploaded by matsant01 ("Create app.py", commit 540a985).
import gradio as gr
import os
import random
import csv
from pathlib import Path
from datetime import datetime
# Root folder holding one sub-directory per sample (validated by get_sample_ids).
# Relative path: resolved against the current working directory.
DATA_DIR = Path("data")
# Folder that receives the CSV of participant choices.
RESULTS_DIR = Path("results")
RESULTS_FILE = RESULTS_DIR / "preferences.csv"
# Extensions tried, in this order, when locating a sample's image files.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
# --- Data Loading ---
def find_image(folder_path: Path, base_name: str) -> Path | None:
"""Finds an image file starting with base_name in a folder."""
for ext in IMAGE_EXTENSIONS:
file_path = folder_path / f"{base_name}{ext}"
if file_path.exists():
return file_path
return None
def get_sample_ids() -> list[str]:
    """Return the names of complete sample folders found directly under DATA_DIR.

    A sub-directory qualifies when it contains ``prompt.txt`` and an image
    (any extension in IMAGE_EXTENSIONS) for each of ``input_bg``,
    ``input_fg``, ``baseline`` and ``tf-icon``. Names are returned in the
    order ``Path.iterdir()`` yields them. The result is an empty list when
    DATA_DIR is not a directory.
    """
    if not DATA_DIR.is_dir():
        return []

    valid_ids = []
    for folder in DATA_DIR.iterdir():
        if not folder.is_dir():
            continue
        has_prompt = (folder / "prompt.txt").exists()
        has_images = all(
            find_image(folder, stem)
            for stem in ("input_bg", "input_fg", "baseline", "tf-icon")
        )
        if has_prompt and has_images:
            valid_ids.append(folder.name)
    return valid_ids
def load_sample_data(sample_id: str) -> dict | None:
    """Load the prompt text and image paths for one sample folder.

    Args:
        sample_id: Name of a sub-directory of DATA_DIR.

    Returns:
        A dict with keys ``id``, ``prompt``, ``input_bg``, ``input_fg``,
        ``output_baseline`` and ``output_tficon``. Image values are path
        strings. Returns ``None`` when the folder is missing, lacks a
        required file, or its prompt cannot be read. The last two cases
        print a message.
    """
    sample_path = DATA_DIR / sample_id
    if not sample_path.is_dir():
        return None

    prompt_file = sample_path / "prompt.txt"
    input_bg_path = find_image(sample_path, "input_bg")
    input_fg_path = find_image(sample_path, "input_fg")
    output_baseline_path = find_image(sample_path, "baseline")
    output_tficon_path = find_image(sample_path, "tf-icon")

    if not all([prompt_file.is_file(), input_bg_path, input_fg_path, output_baseline_path, output_tficon_path]):
        print(f"Warning: Missing files in sample {sample_id}")
        return None

    try:
        # Explicit encoding instead of the platform default (e.g. cp1252 on
        # Windows), so non-ASCII prompts decode the same everywhere.
        # "utf-8-sig" also drops a leading BOM that some editors add.
        prompt = prompt_file.read_text(encoding="utf-8-sig").strip()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading prompt for {sample_id}: {e}")
        return None

    return {
        "id": sample_id,
        "prompt": prompt,
        "input_bg": str(input_bg_path),
        "input_fg": str(input_fg_path),
        "output_baseline": str(output_baseline_path),
        "output_tficon": str(output_tficon_path),
    }
# --- State and UI Logic ---
# Valid sample ids, scanned once when the module is imported. This list is the
# initial value of each session's available-samples state. Samples added to
# DATA_DIR afterwards are not picked up until the app restarts.
INITIAL_SAMPLE_IDS = get_sample_ids()
def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]:
    """Draw a random sample that loads successfully from ``available_ids``.

    A drawn id is removed from the pool whether or not it loads. If it fails
    to load (load_sample_data returns ``None``), another id is drawn. This
    prevents a single broken folder from ending the study while unseen
    samples remain.

    Args:
        available_ids: Ids not yet shown in this session. Not modified.

    Returns:
        ``(sample_data, remaining_ids)`` for the loaded sample, or
        ``(None, [])`` once no loadable sample is left.
    """
    remaining_ids = list(available_ids or [])
    while remaining_ids:
        chosen_id = random.choice(remaining_ids)
        remaining_ids = [sid for sid in remaining_ids if sid != chosen_id]
        sample_data = load_sample_data(chosen_id)
        if sample_data is not None:
            return sample_data, remaining_ids
        print(f"Skipping sample {chosen_id}: could not be loaded.")
    return None, []
def display_new_sample(state: dict, available_ids: list[str]):
    """Pick the next unseen sample and build the Gradio updates that display it.

    Args:
        state: Current per-session state. It is returned as-is only when no
            sample is left.
        available_ids: Sample ids not yet shown in this session.

    Returns:
        A dict mapping UI components to ``gr.update`` objects, plus new values
        for ``app_state`` and ``available_samples_state``.

        When a sample is available, its two outputs are shuffled into slots
        A and B. The slot-to-model mapping goes into the new state so
        record_preference can decode the choice.

        When none is left, every image and button is hidden and a completion
        message is shown.
    """
    sample_data, remaining_ids = get_next_sample(available_ids)

    if not sample_data:
        image_components = (input_bg_display, input_fg_display, output_a_display, output_b_display)
        button_components = (choice_button_a, choice_button_b, next_button)
        updates = {image: gr.update(value=None, visible=False) for image in image_components}
        updates.update({button: gr.update(visible=False) for button in button_components})
        updates[prompt_display] = gr.update(value="No more samples available. Thank you!")
        updates[status_display] = gr.update(value="Completed!")
        updates[app_state] = state
        updates[available_samples_state] = remaining_ids
        return updates

    pair = [
        {"model_name": "baseline", "path": sample_data["output_baseline"]},
        {"model_name": "tf-icon", "path": sample_data["output_tficon"]},
    ]
    # Randomise which model lands in slot A so that position does not reveal the model.
    random.shuffle(pair)
    slot_a, slot_b = pair

    return {
        prompt_display: gr.update(value=f"Prompt: {sample_data['prompt']}"),
        input_bg_display: gr.update(value=sample_data["input_bg"], visible=True),
        input_fg_display: gr.update(value=sample_data["input_fg"], visible=True),
        output_a_display: gr.update(value=slot_a["path"], visible=True),
        output_b_display: gr.update(value=slot_b["path"], visible=True),
        choice_button_a: gr.update(visible=True, interactive=True),
        choice_button_b: gr.update(visible=True, interactive=True),
        next_button: gr.update(visible=False),
        status_display: gr.update(value="Please choose the image you prefer."),
        app_state: {
            "current_sample_id": sample_data["id"],
            "output_a_model_name": slot_a["model_name"],
            "output_b_model_name": slot_b["model_name"],
        },
        available_samples_state: remaining_ids,
    }
def _session_id_from_request(request: gr.Request | None) -> str:
    """Return the client host of ``request`` as a session label, or a fallback label."""
    if not request:
        print("Error: Request object is None. Cannot get session ID.")
        return "unknown_session"
    try:
        # IP address as a basic session identifier. Clients behind the same
        # proxy/NAT will share this value.
        return request.client.host
    except AttributeError:
        print("Error: request.client is None or has no 'host' attribute.")
        return "unknown_client"


def _append_preference_row(row: list) -> None:
    """Append ``row`` to RESULTS_FILE, writing the CSV header first if the file is empty.

    Raises:
        OSError: if RESULTS_DIR or RESULTS_FILE cannot be created or written.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_FILE, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # An append-mode stream starts at end-of-file, so position 0 means the
        # file is new *or* exists but is empty. Both cases need the header.
        if f.tell() == 0:
            writer.writerow([
                "timestamp", "session_id", "sample_id",
                "baseline_displayed_as", "tficon_displayed_as",
                "chosen_display", "chosen_model_name",
            ])
        writer.writerow(row)


def _post_choice_updates(message: str, state: dict) -> dict:
    """Build the UI updates after a choice attempt: lock A/B, reveal Next, show ``message``."""
    return {
        choice_button_a: gr.update(interactive=False),
        choice_button_b: gr.update(interactive=False),
        next_button: gr.update(visible=True, interactive=True),
        status_display: gr.update(value=message),
        app_state: state,
    }


def record_preference(choice: str, state: dict, request: gr.Request):
    """Append the participant's A/B choice to RESULTS_FILE and lock the choice buttons.

    Args:
        choice: ``"A"`` or ``"B"``, the slot the participant clicked.
        state: Per-session state built by display_new_sample. It holds the
            sample id and which model is shown in each slot.
        request: Gradio request. It is used only to derive ``session_id``.

    Returns:
        Updates for the choice buttons, the Next button, the status box and
        ``app_state`` (passed back unchanged). Missing state or a write
        failure is reported in the status box rather than raised, so the
        participant can always move on.
    """
    session_id = _session_id_from_request(request)

    if not state or "current_sample_id" not in state:
        print("Warning: State missing, cannot record preference.")
        return _post_choice_updates("Error: Session state lost. Click Next Sample.", state)

    model_in_a = state["output_a_model_name"]
    # Any choice other than "A" is treated as slot B.
    chosen_model_name = model_in_a if choice == "A" else state["output_b_model_name"]
    baseline_display = "A" if model_in_a == "baseline" else "B"
    tficon_display = "B" if baseline_display == "A" else "A"

    try:
        _append_preference_row([
            datetime.now().isoformat(),
            session_id,
            state["current_sample_id"],
            baseline_display,
            tficon_display,
            choice,             # A or B
            chosen_model_name,  # baseline or tf-icon
        ])
    except Exception as e:  # UI boundary: report to the participant instead of crashing the handler
        print(f"Error writing results: {e}")
        return _post_choice_updates(f"Error saving preference: {e}. Click Next Sample.", state)

    return _post_choice_updates(f"Preference recorded (Chose {choice}). Click Next Sample.", state)
# --- Click handlers: bind the A/B choice and forward to record_preference ---
def handle_choice_a(state: dict, request: gr.Request):
    """Handle a click on "Choose Output A" by recording slot A as the preference."""
    return record_preference(choice="A", state=state, request=request)
def handle_choice_b(state: dict, request: gr.Request):
    """Handle a click on "Choose Output B" by recording slot B as the preference."""
    return record_preference(choice="B", state=state, request=request)
# --- Gradio Interface ---
with gr.Blocks(title="Image Composition User Study") as demo:
    gr.Markdown("# Image Composition User Study")
    gr.Markdown(
        "Please look at the input images and the prompt below. "
        "Then, compare the two output images (Output A and Output B) and click the button below the one you prefer."
    )

    # --- State ---
    # Current sample id and which model is shown as Output A / Output B.
    app_state = gr.State({})
    # Sample ids not yet shown *in this session*.
    available_samples_state = gr.State(INITIAL_SAMPLE_IDS)

    # --- Displays ---
    prompt_display = gr.Textbox(label="Prompt", interactive=False)
    status_display = gr.Textbox(label="Status", value="Loading first sample...", interactive=False)
    with gr.Row():
        input_bg_display = gr.Image(label="Input Background", type="filepath", height=300, width=300, interactive=False)
        input_fg_display = gr.Image(label="Input Foreground", type="filepath", height=300, width=300, interactive=False)

    gr.Markdown("---")
    gr.Markdown("## Choose your preferred output:")
    with gr.Row():
        with gr.Column():
            output_a_display = gr.Image(label="Output A", type="filepath", height=400, width=400, interactive=False)
            choice_button_a = gr.Button("Choose Output A", variant="primary")
        with gr.Column():
            output_b_display = gr.Image(label="Output B", type="filepath", height=400, width=400, interactive=False)
            choice_button_b = gr.Button("Choose Output B", variant="primary")
    next_button = gr.Button("Next Sample", visible=False)

    # --- Event handlers ---
    # Components refreshed whenever a new sample is shown (display_new_sample).
    sample_outputs = [
        prompt_display, input_bg_display, input_fg_display,
        output_a_display, output_b_display,
        choice_button_a, choice_button_b, next_button, status_display,
        app_state, available_samples_state,
    ]
    # Components refreshed after a choice is made (record_preference).
    choice_outputs = [choice_button_a, choice_button_b, next_button, status_display, app_state]

    # Show the first sample on page load. Like every other listener here, it is
    # kept out of the public API (api_name=False).
    demo.load(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=sample_outputs,
        api_name=False,
    )

    # The handlers' gr.Request parameter is supplied by Gradio, so only the
    # state component is listed as an input.
    choice_button_a.click(
        fn=handle_choice_a,
        inputs=[app_state],
        outputs=choice_outputs,
        api_name=False,
    )
    choice_button_b.click(
        fn=handle_choice_b,
        inputs=[app_state],
        outputs=choice_outputs,
        api_name=False,
    )

    next_button.click(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=sample_outputs,
        api_name=False,
    )
if __name__ == "__main__":
    import sys

    if not INITIAL_SAMPLE_IDS:
        # Fail loudly: report on stderr and exit non-zero so a supervisor or
        # hosting platform sees a failed start, not a clean exit.
        print("Error: No valid samples found in the 'data' directory.", file=sys.stderr)
        print("Please ensure the 'data' directory exists and contains subdirectories", file=sys.stderr)
        print("named like 'sample_id', each with 'prompt.txt', 'input_bg.*',", file=sys.stderr)
        print("'input_fg.*', 'baseline.*', and 'tf-icon.*' files.", file=sys.stderr)
        sys.exit(1)

    print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
    print("Starting Gradio app...")
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0")