import gradio as gr
import pandas as pd
import os
import random
from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
from PIL import Image
from filelock import FileLock # Added for file locking
import config
import utils
# --- Global Variables & Initial Setup ---

# Expected header row as a list. Comparing list(df.columns) against a list (not
# against config.CSV_HEADERS directly) keeps the check correct even if the config
# declares the headers as a tuple, which would never equal a list and would wipe
# the local CSV on every start.
_EXPECTED_CSV_HEADERS = list(config.CSV_HEADERS)


def _read_local_csv_headers(path):
    """Return the header row of the CSV at *path*, or [] if it is missing, empty or unreadable."""
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return []
    try:
        return list(pd.read_csv(path, nrows=0).columns)
    except Exception:  # Unreadable header row: callers treat [] as "needs rewrite".
        return []


def _reset_local_preferences_csv():
    """Overwrite the local results CSV with only the expected headers; return the matching empty DataFrame.

    NOTE: this discards any rows previously stored in the local file.
    """
    empty_df = pd.DataFrame(columns=config.CSV_HEADERS)
    empty_df.to_csv(config.RESULTS_CSV_FILE, index=False)
    return empty_df


def _load_local_preferences_fallback():
    """Load preferences from the local CSV when the Hub copy could not be loaded.

    A missing, empty, unreadable or wrongly-headed file is re-initialized to a
    header-only CSV so that later row appends stay consistent.
    """
    path = config.RESULTS_CSV_FILE
    if not os.path.exists(path):
        print(f"Local CSV {config.RESULTS_CSV_FILE} not found. Initializing file and DataFrame.")
        return _reset_local_preferences_csv()
    try:
        local_df = pd.read_csv(path)
        if not local_df.empty:
            if list(local_df.columns) != _EXPECTED_CSV_HEADERS:
                print(f"Warning: Local CSV {config.RESULTS_CSV_FILE} columns ({list(local_df.columns)}) do not match expected headers ({config.CSV_HEADERS}). Re-initializing file and DataFrame.")
                return _reset_local_preferences_csv()
            return local_df
        # Loaded an empty DataFrame: check whether the file's header row itself is correct.
        if _read_local_csv_headers(path) != _EXPECTED_CSV_HEADERS:
            print(f"Local CSV {config.RESULTS_CSV_FILE} is empty or has incorrect headers. Re-initializing file and DataFrame with correct headers.")
            return _reset_local_preferences_csv()
        # Headers on disk are fine; make sure the in-memory frame carries them too.
        return pd.DataFrame(columns=config.CSV_HEADERS)
    except pd.errors.EmptyDataError:
        print(f"Local CSV {config.RESULTS_CSV_FILE} is empty. Initializing file and DataFrame with headers.")
        return _reset_local_preferences_csv()
    except Exception as e:
        print(f"Error loading local {config.RESULTS_CSV_FILE}: {e}. Initializing file and DataFrame.")
        return _reset_local_preferences_csv()


def _reconcile_hub_preferences(hub_df):
    """Validate the DataFrame returned by the Hub sync and repair the local CSV header if needed."""
    print(f"Successfully loaded preferences from Hugging Face Hub. Local copy at {config.RESULTS_CSV_FILE} should be up-to-date.")
    if not hub_df.empty and list(hub_df.columns) != _EXPECTED_CSV_HEADERS:
        print(f"CRITICAL: Data from Hub has incorrect columns {list(hub_df.columns)}. Expected {config.CSV_HEADERS}. Re-initializing local file and DataFrame to empty to prevent corruption.")
        return _reset_local_preferences_csv()
    if hub_df.empty:
        # Hub data is empty: fix the in-memory columns and make sure the local
        # CSV (expected to be written by load_preferences_from_hf_hub) has the right header.
        if list(hub_df.columns) != _EXPECTED_CSV_HEADERS:
            hub_df = pd.DataFrame(columns=config.CSV_HEADERS)
        if _read_local_csv_headers(config.RESULTS_CSV_FILE) != _EXPECTED_CSV_HEADERS:
            print(f"Local file {config.RESULTS_CSV_FILE} (after Hub sync resulted in empty data) is empty or has incorrect headers. Writing/Re-writing headers.")
            pd.DataFrame(columns=config.CSV_HEADERS).to_csv(config.RESULTS_CSV_FILE, index=False)
    return hub_df


# Attempt to log in to Hugging Face Hub at startup
utils.login_hugging_face()
print(f"Attempting to load preferences from Hugging Face Hub, ensuring local {config.RESULTS_CSV_FILE} is synchronized.")
# We assume utils.load_preferences_from_hf_hub:
# 1. Downloads from Hub, overwrites local config.RESULTS_CSV_FILE.
# 2. If the Hub file doesn't exist, local config.RESULTS_CSV_FILE becomes empty (or reflects this).
# 3. Returns the DataFrame loaded from the (now synchronized) local file.
# 4. Returns None on major failure (e.g. network, file not found on Hub).
preferences_df = utils.load_preferences_from_hf_hub(config.HF_DATASET_REPO_ID, config.RESULTS_CSV_FILE)
if preferences_df is None:
    print(f"Failed to load from Hub or Hub is empty/file not found. Initializing/loading from {config.RESULTS_CSV_FILE} as a fallback.")
    preferences_df = _load_local_preferences_fallback()
else:
    preferences_df = _reconcile_hub_preferences(preferences_df)

# Final safety net: ensure preferences_df is a DataFrame with correct columns.
if not isinstance(preferences_df, pd.DataFrame) or list(preferences_df.columns) != _EXPECTED_CSV_HEADERS:
    print("Critical: preferences_df is not a valid DataFrame with correct headers after initialization. Resetting to empty with correct headers.")
    preferences_df = pd.DataFrame(columns=config.CSV_HEADERS)
    # Ensure the CSV file reflects this state
    preferences_df.to_csv(config.RESULTS_CSV_FILE, index=False)
# Index every available sample once at import time; new sessions draw from this mapping.
ALL_SAMPLES_BY_DOMAIN = utils.scan_data_directory(config.DATA_FOLDER)
if not ALL_SAMPLES_BY_DOMAIN:
    # Only logged for now. TODO: consider raising or surfacing this in the UI.
    print(f"CRITICAL: No data found in {config.DATA_FOLDER}. The app might not function correctly.")
# --- Scheduler for Periodic Uploads ---
def scheduled_upload_job():
    """Push the local preferences CSV to the Hugging Face Hub dataset repo.

    The whole read-and-upload runs while holding the same ``<csv>.lock`` file
    lock that vote processing uses when appending rows. Nothing is uploaded if
    the file is missing or holds no rows. Read/upload errors are printed
    rather than raised.
    """
    csv_path = config.RESULTS_CSV_FILE
    print(f"Running scheduled job: Preparing to upload preferences from {csv_path} at {datetime.now()}")
    lock_path = csv_path + ".lock"
    with FileLock(lock_path):
        print(f"Acquired lock for scheduled upload: {lock_path}")
        if not os.path.exists(csv_path):
            print(f"Scheduled job: Local preferences file {csv_path} does not exist. Nothing to upload.")
        else:
            try:
                snapshot = pd.read_csv(csv_path)
                if snapshot.empty:
                    print(f"Scheduled job: Local preferences file {csv_path} is empty. Nothing to upload.")
                else:
                    utils.save_preferences_to_hf_hub(
                        snapshot,  # per the original note: used for the empty check inside the helper
                        config.HF_DATASET_REPO_ID,
                        csv_path,  # per the original note: the target filename on the Hub
                        commit_message="Periodic background update",
                    )
                    print(f"Scheduled job: Attempted upload of data from {csv_path}.")
            except pd.errors.EmptyDataError:
                print(f"Scheduled job: Local preferences file {csv_path} is empty (read as EmptyDataError). Nothing to upload.")
            except Exception as e:
                print(f"Scheduled job: Error reading or uploading {csv_path}: {e}")
    print(f"Released lock for scheduled upload: {lock_path}")
# Run scheduled_upload_job every PUSH_INTERVAL_HOURS on APScheduler's background thread.
scheduler = BackgroundScheduler()
scheduler.add_job(
    scheduled_upload_job,
    trigger="interval",
    hours=config.PUSH_INTERVAL_HOURS,
)
scheduler.start()
print(f"Scheduler started. Will attempt to upload preferences every {config.PUSH_INTERVAL_HOURS} hour(s).")
# --- Core Gradio App Functions ---
def start_new_session():
    """Initialize a new user session and load its first displayable sample.

    Returns:
        A 10-tuple ``(session_id, sample_queue, current_sample_index,
        prompt_or_end_markdown, bg_image, fg_image, sample_data,
        output_images, displayed_models_info, session_ended)``.
        Samples whose data fails to load are skipped, so the returned
        ``current_sample_index`` can be greater than 0.
    """
    session_id = utils.generate_session_id()
    sample_queue = utils.prepare_session_samples(ALL_SAMPLES_BY_DOMAIN, config.SAMPLES_PER_DOMAIN)
    current_sample_index = 0
    if not sample_queue:
        no_samples_msg = "# 😥 No Samples Available!\n\n### Please check the data folder configuration or try again later."
        return session_id, sample_queue, current_sample_index, no_samples_msg, None, None, None, [], [], True
    print(f"New session started: {session_id}, with {len(sample_queue)} samples.")

    shown = load_and_display_sample(sample_queue, current_sample_index)
    # load_and_display_sample reports a broken sample as sample_data None with
    # session_ended False. The UI disables every vote button in that state, so
    # without skipping here the session would be stuck on the error message.
    # The loop terminates: past the end of the queue session_ended is True.
    while shown[3] is None and not shown[6]:
        current_sample_index += 1
        shown = load_and_display_sample(sample_queue, current_sample_index)

    domain_prompt_md, bg, fg, s_data, out_imgs, disp_info, end_flag = shown
    return session_id, sample_queue, current_sample_index, domain_prompt_md, bg, fg, s_data, out_imgs, disp_info, end_flag
def _load_image_copy(path, size=None):
    """Read the image at *path* fully into memory, optionally resized to *size*.

    Image.open is lazy and keeps the file open until the pixels are used;
    reading inside ``with`` and returning a copy/resized image closes the
    handle right away.
    """
    with Image.open(path) as img:
        return img.resize(size) if size is not None else img.copy()


def _sample_load_error(domain, sample_id):
    """Return the display tuple used when a sample cannot be shown (session not ended)."""
    error_msg = f"## ⚠️ Error Loading Sample\n\nCould not load data for {domain}/{sample_id}. Skipping to the next one."
    return error_msg, None, None, None, [], [], False


def load_and_display_sample(sample_queue, current_sample_index):
    """Load one sample and prepare everything the UI needs to display it.

    Args:
        sample_queue: The session's sequence of ``(domain, sample_id)`` pairs.
        current_sample_index: Position in ``sample_queue`` to load.

    Returns:
        A 7-tuple ``(prompt_markdown, bg_image, fg_image, sample_data,
        output_images, displayed_models_info, session_ended)``:

        * past the end of the queue: an end-of-session message, ``None``
          images/data, empty lists and ``True``;
        * when the sample cannot be shown (no sample data, an input image that
          cannot be opened, or no output image that could be opened): an error
          message, ``None`` images/data, empty lists and ``False``;
        * otherwise: the prompt markdown, the input images, the sample data,
          four output images in shuffled order (padded with grey placeholders),
          the matching ``(model_key, image_path)`` list with placeholders as
          ``("BLANK_SLOT", "N/A")``, and ``False``.
    """
    if not sample_queue or current_sample_index >= len(sample_queue):
        end_session_msg = "# 🎉 All Rated! 🎉\n\n### All samples for this session have been rated. Thank you!"
        return end_session_msg, None, None, None, [], [], True  # End of session

    domain, sample_id = sample_queue[current_sample_index]
    sample_data = utils.load_sample_data(domain, sample_id)
    if sample_data is None:
        print(f"Error loading sample {domain}/{sample_id}. Skipping.")
        return _sample_load_error(domain, sample_id)

    prompt_text = sample_data["prompt"]
    bg_img_path = sample_data["background_img_path"]
    fg_img_path = sample_data["foreground_img_path"]
    try:
        # Input images keep their aspect ratio; gr.Image scales them to the configured height.
        bg_image_to_display = _load_image_copy(bg_img_path)
        fg_image_to_display = _load_image_copy(fg_img_path)
    except Exception as e:
        print(f"Error loading input images for sample {domain}/{sample_id}: {e}. Skipping.")
        return _sample_load_error(domain, sample_id)

    # Shuffle so the on-screen position does not reveal which model produced an image.
    output_model_keys = list(sample_data["output_image_paths"].keys())
    random.shuffle(output_model_keys)

    # Output options are always shown as squares of the configured display width.
    square_size = (config.IMAGE_DISPLAY_SIZE[0], config.IMAGE_DISPLAY_SIZE[0])
    output_images_for_display = []
    displayed_models_info = []
    for model_key in output_model_keys:
        img_path = sample_data["output_image_paths"][model_key]
        try:
            resized = _load_image_copy(img_path, square_size)
        except FileNotFoundError:
            print(f"Image not found: {img_path} for model {model_key}. Skipping this option.")
            continue
        except Exception as e:
            print(f"Error loading or resizing image {img_path}: {e}. Skipping this option.")
            continue
        output_images_for_display.append(resized)
        displayed_models_info.append((model_key, img_path))

    if not output_images_for_display:
        # Only blank placeholders would be shown and every vote button disabled.
        print(f"No output images could be loaded for sample {domain}/{sample_id}. Skipping.")
        return _sample_load_error(domain, sample_id)

    blank_image = Image.new('RGB', square_size, (200, 200, 200))
    while len(output_images_for_display) < 4:
        output_images_for_display.append(blank_image)
        displayed_models_info.append(("BLANK_SLOT", "N/A"))

    domain_prompt_markdown = f"\n### Domain: {domain}\n## Prompt\n### _\"{prompt_text}\"_\n"
    return (
        domain_prompt_markdown,
        bg_image_to_display,
        fg_image_to_display,
        sample_data,
        output_images_for_display[:4],
        displayed_models_info[:4],
        False,
    )
def process_vote(choice_index, session_id, sample_queue, current_sample_index, current_sample_data, displayed_models_info_for_sample):
    """Record the user's choice for the current sample and advance to the next one.

    Args:
        choice_index: Index (0-3) of the option the user selected.
        session_id: Identifier of the voting session.
        sample_queue: The session's sequence of ``(domain, sample_id)`` pairs.
        current_sample_index: Index of the sample being voted on.
        current_sample_data: Data dict of the sample being voted on.
        displayed_models_info_for_sample: ``(model_key, image_path)`` per displayed option.

    Returns:
        A 9-tuple ``(preferences_df, current_sample_index, prompt_or_end_markdown,
        bg_image, fg_image, sample_data, output_images, displayed_models_info,
        session_ended)`` describing what the UI should show next.

    Side effects:
        Rebinds the module-level ``preferences_df`` and appends the new vote as
        one row to ``config.RESULTS_CSV_FILE`` while holding the shared file lock.
    """
    global preferences_df

    def _show_from(index):
        # Load the sample at *index*, skipping samples that fail to load
        # (sample_data None while the session has not ended). The UI disables
        # all vote buttons for such a sample, so stopping on it would stall the
        # session. Returns (index_actually_shown, *7-tuple from load_and_display_sample).
        shown = load_and_display_sample(sample_queue, index)
        while shown[3] is None and not shown[6]:
            index += 1
            shown = load_and_display_sample(sample_queue, index)
        return (index, *shown)

    if current_sample_data is None or not displayed_models_info_for_sample or choice_index >= len(displayed_models_info_for_sample):
        print("Error: Invalid data for processing vote. Skipping.")
        current_sample_index += 1
        if current_sample_index >= len(sample_queue):
            error_end_msg = "# ⚠️ Error Processing Vote ⚠️\n\n### An issue occurred. The session has ended."
            return preferences_df, current_sample_index, error_end_msg, None, None, None, [], [], True
        return (preferences_df, *_show_from(current_sample_index))

    domain, sample_id = sample_queue[current_sample_index]
    preferred_model_key, _ = displayed_models_info_for_sample[choice_index]
    if preferred_model_key == "BLANK_SLOT":
        print("User clicked on a blank slot. Vote not recorded. Please select a valid image.")
        return (preferences_df, *_show_from(current_sample_index))

    print(f"Session {session_id}: Voted for model '{config.MODEL_DISPLAY_NAMES.get(preferred_model_key, preferred_model_key)}' (key: {preferred_model_key}) for sample {domain}/{sample_id}")
    updated_df = utils.record_preference(
        df=preferences_df,
        session_id=session_id,
        domain=domain,
        sample_id=sample_id,
        prompt=current_sample_data["prompt"],
        bg_path=current_sample_data["background_img_path"],
        fg_path=current_sample_data["foreground_img_path"],
        displayed_models_info=displayed_models_info_for_sample,
        preferred_model_key=preferred_model_key
    )
    preferences_df = updated_df

    # Append only the new row to the CSV. Read it from this call's own result,
    # not from the module global, which a concurrent vote may rebind in between.
    # Assumes utils.record_preference appends the vote as the last row.
    if updated_df.empty:
        print("Warning: preferences_df is empty after utils.record_preference. Cannot append to CSV.")
    else:
        new_preference_df = updated_df.iloc[-1:]
        lock_path = config.RESULTS_CSV_FILE + ".lock"
        with FileLock(lock_path):
            print(f"Acquired lock for vote processing: {lock_path}")
            try:
                file_exists_and_has_content = os.path.exists(config.RESULTS_CSV_FILE) and os.path.getsize(config.RESULTS_CSV_FILE) > 0
                new_preference_df.to_csv(
                    config.RESULTS_CSV_FILE,
                    mode='a',
                    header=not file_exists_and_has_content,  # Write header if file is new or empty
                    index=False
                )
                print(f"Appended new preference to {config.RESULTS_CSV_FILE}")
            except Exception as e:
                print(f"Error appending preference to local CSV {config.RESULTS_CSV_FILE}: {e}")
        # Logged after leaving the `with`, i.e. once the lock is actually released.
        print(f"Released lock for vote processing: {lock_path}")

    current_sample_index += 1
    if current_sample_index >= len(sample_queue):
        final_msg = "# 🎉 Session Complete! 🎉\n\n### All samples have been rated. Thank you for your participation!"
        return updated_df, current_sample_index, final_msg, None, None, None, [], [], True
    return (updated_df, *_show_from(current_sample_index))
# --- Gradio UI Definition ---
# Extra stylesheet passed to gr.Blocks(css=...). It styles elements carrying the
# "custom-vote-button" class (the four "Select Option" buttons): orange normally,
# darker orange on hover. The !important flags presumably make these rules win
# over the theme's own button styles.
custom_css = """
.custom-vote-button {
background-color: #FFA500 !important; /* Light Orange for normal state */
border-color: #FFA500 !important; /* Light Orange for normal state */
color: white !important;
}
.custom-vote-button:hover {
background-color: #FF8C00 !important; /* Dark Orange for hover state */
border-color: #FF8C00 !important; /* Dark Orange for hover state */
color: white !important;
}
"""
with gr.Blocks(title=config.APP_TITLE, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue), css=custom_css) as demo:
    # Per-session state: gr.State keeps a separate value for each browser session.
    session_id_state = gr.State()
    sample_queue_state = gr.State([])
    current_sample_index_state = gr.State(0)
    current_sample_data_state = gr.State()
    displayed_models_info_state = gr.State([])
    preferences_df_state = gr.State(value=preferences_df)

    gr.Markdown(f"# {config.APP_TITLE}")
    gr.Markdown(config.APP_DESCRIPTION)
    with gr.Row():
        start_button = gr.Button("Start New Session / Load First Sample", variant="primary")
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            domain_prompt_info_display = gr.Markdown(value="### Click 'Start New Session' to begin.")
        with gr.Column(scale=2):
            with gr.Row():
                input_bg_image_display = gr.Image(label="Input Background", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)
                input_fg_image_display = gr.Image(label="Input Foreground", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)
    gr.Markdown("---")
    gr.Markdown("## Choose your preferred composed image:")
    output_image_displays = []
    vote_buttons = []
    with gr.Row():
        for i in range(4):
            with gr.Column():
                img_display = gr.Image(label=f"Option {i+1}", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)
                output_image_displays.append(img_display)
                vote_btn = gr.Button(f"Select Option {i+1}", elem_id=f"vote_btn_{i}", elem_classes=["custom-vote-button"])
                vote_buttons.append(vote_btn)
    end_of_session_msg_display = gr.Markdown("", visible=True)

    def _display_updates(prompt_or_end_msg, bg, fg, s_data, out_imgs, disp_info, end):
        """Build the component-update mapping shared by the start and vote handlers.

        Pads the option images / model info to four slots, routes the message
        to the prompt area (running session) or the end-of-session area
        (ended), and enables a vote button only for a slot that holds a real
        model output while the session is still running.
        """
        out_imgs = list(out_imgs) + [None] * (4 - len(out_imgs))
        disp_info = list(disp_info) + [("BLANK_SLOT", "N/A")] * (4 - len(disp_info))
        # Computed once per update instead of once per button.
        num_actual_outputs = 0
        if s_data and "output_image_paths" in s_data and s_data["output_image_paths"]:
            num_actual_outputs = sum(1 for m_key, _ in disp_info if m_key != "BLANK_SLOT" and m_key is not None)
        updates = {
            domain_prompt_info_display: prompt_or_end_msg if not end else "",
            input_bg_image_display: bg,
            input_fg_image_display: fg,
            current_sample_data_state: s_data,
            displayed_models_info_state: disp_info,
            end_of_session_msg_display: prompt_or_end_msg if end else "",
        }
        for slot, (image_component, button) in enumerate(zip(output_image_displays, vote_buttons)):
            updates[image_component] = out_imgs[slot]
            updates[button] = gr.Button(interactive=not end and slot < num_actual_outputs)
        return updates

    def handle_start_session():
        """Start a new session and render its first sample (or the no-samples message)."""
        s_id, s_queue, s_idx, domain_prompt_or_end_msg, bg, fg, s_data, out_imgs, disp_info, end = start_new_session()
        updates = _display_updates(domain_prompt_or_end_msg, bg, fg, s_data, out_imgs, disp_info, end)
        updates[session_id_state] = s_id
        updates[sample_queue_state] = s_queue
        updates[current_sample_index_state] = s_idx
        return updates

    start_button.click(
        fn=handle_start_session,
        inputs=[],
        outputs=[
            session_id_state, sample_queue_state, current_sample_index_state,
            domain_prompt_info_display,
            input_bg_image_display, input_fg_image_display,
            current_sample_data_state, displayed_models_info_state, end_of_session_msg_display,
            *output_image_displays, *vote_buttons
        ]
    )

    def make_vote_fn(choice_idx):
        """Return a click handler that votes for option *choice_idx*.

        A factory binds each handler to its own index; a lambda in the loop
        below would late-bind to the final loop value.
        """
        def vote_action(s_id, s_queue, s_idx, current_s_data, disp_info_for_sample, prefs_df_val):
            global preferences_df
            # process_vote reads the module-level preferences_df, so seed it with
            # this session's copy from gr.State before recording the vote.
            preferences_df = prefs_df_val
            new_prefs_df, new_s_idx, domain_prompt_or_end_msg, bg, fg, new_s_data, out_imgs, new_disp_info, end = process_vote(
                choice_idx, s_id, s_queue, s_idx, current_s_data, disp_info_for_sample
            )
            updates = _display_updates(domain_prompt_or_end_msg, bg, fg, new_s_data, out_imgs, new_disp_info, end)
            updates[preferences_df_state] = new_prefs_df
            updates[current_sample_index_state] = new_s_idx
            return updates
        return vote_action

    for i, btn in enumerate(vote_buttons):
        btn.click(
            fn=make_vote_fn(i),
            inputs=[
                session_id_state, sample_queue_state, current_sample_index_state,
                current_sample_data_state, displayed_models_info_state, preferences_df_state
            ],
            outputs=[
                preferences_df_state, current_sample_index_state,
                domain_prompt_info_display,
                input_bg_image_display, input_fg_image_display,
                current_sample_data_state, displayed_models_info_state, end_of_session_msg_display,
                *output_image_displays, *vote_buttons
            ]
        )

    gr.Markdown(config.FOOTER_MESSAGE)
def _create_dummy_data(data_folder):
    """Populate *data_folder* with placeholder samples so the app can be tried locally.

    Creates the domains "Real-Cartoon" and "Real-Painting", each with
    ``config.SAMPLES_PER_DOMAIN + 2`` samples. Every sample gets a prompt file,
    solid-color background/foreground images and one solid-color image per
    entry in ``config.MODEL_OUTPUT_IMAGE_NAMES``.
    """
    os.makedirs(data_folder, exist_ok=True)
    dummy_domains = ["Real-Cartoon", "Real-Painting"]
    # Model output i is filled with colors[i % len(colors)].
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
    for domain in dummy_domains:
        domain_path = os.path.join(data_folder, domain)
        os.makedirs(domain_path, exist_ok=True)
        for i in range(config.SAMPLES_PER_DOMAIN + 2):
            sample_id = f"sample_{i:03d}"
            sample_path = os.path.join(domain_path, sample_id)
            os.makedirs(sample_path, exist_ok=True)
            with open(os.path.join(sample_path, config.PROMPT_FILE_NAME), "w", encoding="utf-8") as f:
                f.write(f"This is a dummy prompt for {domain} sample {sample_id}.")
            try:
                img_bg = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color='gray')
                img_bg.save(os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME))
                img_fg = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color='lightgray')
                img_fg.save(os.path.join(sample_path, config.FOREGROUND_IMAGE_NAME))
                for idx, model_img_name in enumerate(config.MODEL_OUTPUT_IMAGE_NAMES.values()):
                    img_model = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color=colors[idx % len(colors)])
                    img_model.save(os.path.join(sample_path, model_img_name))
            except Exception as e:
                print(f"Error creating dummy image: {e}")


# Stop the background upload scheduler at interpreter exit. Registered here,
# before demo.launch() (which blocks by default), so the hook is already in place
# even if launch() raises.
import atexit
atexit.register(lambda: scheduler.shutdown() if scheduler.running else None)

if __name__ == "__main__":
    if not os.path.exists(config.DATA_FOLDER):
        print(f"Creating dummy data folder: {config.DATA_FOLDER}")
        _create_dummy_data(config.DATA_FOLDER)
        print("Dummy data creation complete.")
        # Rescan so new sessions see the freshly created samples.
        ALL_SAMPLES_BY_DOMAIN = utils.scan_data_directory(config.DATA_FOLDER)
    demo.launch()