"""Gradio app that collects human preferences between composed images.

At import time the module:

* logs in to the Hugging Face Hub and loads the preferences CSV (falling back
  to the local copy, and finally to an empty table with the expected headers),
* scans the data folder for samples,
* starts a background scheduler that periodically uploads the local CSV,
* builds the Gradio ``demo`` UI.

Running the file as a script additionally creates dummy data when the data
folder is missing and launches the UI.
"""

import atexit
import os
import random
from datetime import datetime

import gradio as gr
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler
from filelock import FileLock  # Serializes access to the local results CSV
from PIL import Image

import config
import utils

# Number of output options (image + vote button) shown for every sample.
NUM_OUTPUT_SLOTS = 4
# Placeholder entry used to pad ``displayed_models_info`` up to NUM_OUTPUT_SLOTS.
BLANK_SLOT_INFO = ("BLANK_SLOT", "N/A")


# --- Global Variables & Initial Setup ---

def _reset_local_preferences():
    """Overwrite the local results CSV with only the header row.

    Returns:
        An empty DataFrame whose columns are ``config.CSV_HEADERS``.
    """
    empty_df = pd.DataFrame(columns=config.CSV_HEADERS)
    empty_df.to_csv(config.RESULTS_CSV_FILE, index=False)
    return empty_df


def _local_csv_headers():
    """Return the header row of the local results CSV.

    Returns:
        The list of column names, or ``None`` when the file is missing,
        zero-length, or its header row cannot be parsed.
    """
    path = config.RESULTS_CSV_FILE
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return None
    try:
        return list(pd.read_csv(path, nrows=0).columns)
    except Exception as e:
        # Best effort: an unreadable header is treated like a missing one, and
        # the callers respond by rewriting the file with the expected headers.
        print(f"Could not read headers from {path}: {e}")
        return None


def _load_local_preferences():
    """Load preferences from the local CSV when the Hub load failed.

    Any problem (missing file, empty file, unreadable file, unexpected
    columns) results in the local file being re-initialized with the expected
    headers.

    Returns:
        A DataFrame with the rows read from the local file, or an empty
        DataFrame with ``config.CSV_HEADERS`` columns.
    """
    if not os.path.exists(config.RESULTS_CSV_FILE):
        print(f"Local CSV {config.RESULTS_CSV_FILE} not found. Initializing file and DataFrame.")
        return _reset_local_preferences()

    try:
        local_df = pd.read_csv(config.RESULTS_CSV_FILE)
    except pd.errors.EmptyDataError:
        print(f"Local CSV {config.RESULTS_CSV_FILE} is empty. Initializing file and DataFrame with headers.")
        return _reset_local_preferences()
    except Exception as e:
        print(f"Error loading local {config.RESULTS_CSV_FILE}: {e}. Initializing file and DataFrame.")
        return _reset_local_preferences()

    if not local_df.empty:
        if list(local_df.columns) != config.CSV_HEADERS:
            print(f"Warning: Local CSV {config.RESULTS_CSV_FILE} columns ({list(local_df.columns)}) do not match expected headers ({config.CSV_HEADERS}). Re-initializing file and DataFrame.")
            return _reset_local_preferences()
        return local_df

    # No data rows: decide whether the header row on disk is still usable.
    if _local_csv_headers() != config.CSV_HEADERS:
        print(f"Local CSV {config.RESULTS_CSV_FILE} is empty or has incorrect headers. Re-initializing file and DataFrame with correct headers.")
        return _reset_local_preferences()
    # Headers on disk are correct; only make sure the in-memory frame has them.
    return pd.DataFrame(columns=config.CSV_HEADERS)


def _validate_hub_preferences(hub_df):
    """Validate the DataFrame returned by the Hub load.

    Args:
        hub_df: DataFrame returned by ``utils.load_preferences_from_hf_hub``.

    Returns:
        ``hub_df`` when it has rows and the expected columns; otherwise an
        empty DataFrame with ``config.CSV_HEADERS`` columns. The local CSV is
        rewritten with the expected headers whenever its own header row is
        missing or wrong.
    """
    print(f"Successfully loaded preferences from Hugging Face Hub. Local copy at {config.RESULTS_CSV_FILE} should be up-to-date.")

    if not hub_df.empty:
        if list(hub_df.columns) != config.CSV_HEADERS:
            print(f"CRITICAL: Data from Hub has incorrect columns {list(hub_df.columns)}. Expected {config.CSV_HEADERS}. Re-initializing local file and DataFrame to empty to prevent corruption.")
            return _reset_local_preferences()
        return hub_df

    # Hub data is empty: fix the in-memory columns, then the file on disk.
    if list(hub_df.columns) != config.CSV_HEADERS:
        hub_df = pd.DataFrame(columns=config.CSV_HEADERS)
    if _local_csv_headers() != config.CSV_HEADERS:
        print(f"Local file {config.RESULTS_CSV_FILE} (after Hub sync resulted in empty data) is empty or has incorrect headers. Writing/Re-writing headers.")
        _reset_local_preferences()
    return hub_df


# Attempt to log in to Hugging Face Hub at startup
utils.login_hugging_face()

print(f"Attempting to load preferences from Hugging Face Hub, ensuring local {config.RESULTS_CSV_FILE} is synchronized.")
# We assume utils.load_preferences_from_hf_hub:
# 1. Downloads from Hub, overwrites local config.RESULTS_CSV_FILE.
# 2. If Hub file doesn't exist, local config.RESULTS_CSV_FILE becomes empty (or reflects this).
# 3. Returns the DataFrame loaded from the (now synchronized) local file.
# 4. Returns None on major failure (e.g. network, file not found on Hub).
preferences_df = utils.load_preferences_from_hf_hub(config.HF_DATASET_REPO_ID, config.RESULTS_CSV_FILE)

if preferences_df is None:
    print(f"Failed to load from Hub or Hub is empty/file not found. Initializing/loading from {config.RESULTS_CSV_FILE} as a fallback.")
    preferences_df = _load_local_preferences()
else:
    # Successfully loaded from Hub; local file config.RESULTS_CSV_FILE should be synchronized.
    preferences_df = _validate_hub_preferences(preferences_df)

# Final safety net: ensure preferences_df is a DataFrame with correct columns.
if not isinstance(preferences_df, pd.DataFrame) or list(preferences_df.columns) != config.CSV_HEADERS:
    print("Critical: preferences_df is not a valid DataFrame with correct headers after initialization. Resetting to empty with correct headers.")
    # Ensure the CSV file reflects this state as well.
    preferences_df = _reset_local_preferences()

# Scan for available data
ALL_SAMPLES_BY_DOMAIN = utils.scan_data_directory(config.DATA_FOLDER)
if not ALL_SAMPLES_BY_DOMAIN:
    # Potentially raise an error or display a message in the UI if no data
    print(f"CRITICAL: No data found in {config.DATA_FOLDER}. The app might not function correctly.")


# --- Scheduler for Periodic Uploads ---

def scheduled_upload_job():
    """Upload the current local results CSV to the Hugging Face Hub.

    The CSV is re-read from disk under the same file lock that vote processing
    uses, so a partially written row is never uploaded. Read and upload errors
    are logged and swallowed so the scheduler keeps running.
    """
    print(f"Running scheduled job: Preparing to upload preferences from {config.RESULTS_CSV_FILE} at {datetime.now()}")
    lock_path = config.RESULTS_CSV_FILE + ".lock"

    with FileLock(lock_path):
        print(f"Acquired lock for scheduled upload: {lock_path}")
        if os.path.exists(config.RESULTS_CSV_FILE):
            try:
                # Read the current state of the CSV file for upload
                df_to_upload = pd.read_csv(config.RESULTS_CSV_FILE)
                if not df_to_upload.empty:
                    utils.save_preferences_to_hf_hub(
                        df_to_upload,
                        config.HF_DATASET_REPO_ID,
                        config.RESULTS_CSV_FILE,  # This is the target filename on the Hub
                        commit_message="Periodic background update",
                    )
                    print(f"Scheduled job: Attempted upload of data from {config.RESULTS_CSV_FILE}.")
                else:
                    print(f"Scheduled job: Local preferences file {config.RESULTS_CSV_FILE} is empty. Nothing to upload.")
            except pd.errors.EmptyDataError:
                print(f"Scheduled job: Local preferences file {config.RESULTS_CSV_FILE} is empty (read as EmptyDataError). Nothing to upload.")
            except Exception as e:
                print(f"Scheduled job: Error reading or uploading {config.RESULTS_CSV_FILE}: {e}")
        else:
            print(f"Scheduled job: Local preferences file {config.RESULTS_CSV_FILE} does not exist. Nothing to upload.")
    print(f"Released lock for scheduled upload: {lock_path}")


def _shutdown_scheduler():
    """Stop the background scheduler if it is still running."""
    if scheduler.running:
        scheduler.shutdown()


scheduler = BackgroundScheduler()
scheduler.add_job(scheduled_upload_job, 'interval', hours=config.PUSH_INTERVAL_HOURS)
scheduler.start()
# Register the shutdown hook now, before anything blocks on demo.launch().
atexit.register(_shutdown_scheduler)
print(f"Scheduler started. Will attempt to upload preferences every {config.PUSH_INTERVAL_HOURS} hour(s).")


# --- Core Gradio App Functions ---

def _load_image(path):
    """Open ``path`` with PIL and decode it immediately.

    Decoding eagerly makes a missing or corrupt file fail here, where the
    caller can handle it, instead of later when the image is first used.

    Raises:
        OSError: if the file is missing or cannot be identified/decoded
            (other PIL decoding errors propagate unchanged).
    """
    img = Image.open(path)
    try:
        img.load()
    except Exception:
        img.close()
        raise
    return img


def _sample_error_result(domain, sample_id):
    """Build the display tuple used when a sample cannot be shown."""
    error_msg = f"## ⚠️ Error Loading Sample\n\nCould not load data for {domain}/{sample_id}. Skipping to the next one."
    return error_msg, None, None, None, [], [], False


def _load_first_available_sample(sample_queue, start_index):
    """Load the first displayable sample at or after ``start_index``.

    Samples that fail to load are skipped, so the UI is never left showing a
    load error with every vote button disabled.

    Returns:
        ``(index, display)`` where ``display`` is the 7-tuple returned by
        ``load_and_display_sample`` and ``index`` is the queue position it
        belongs to (``len(sample_queue)`` or more when the queue ran out).
    """
    index = start_index
    display = load_and_display_sample(sample_queue, index)
    # display[3] is the sample data (None on load failure); display[6] is the end flag.
    while not display[6] and display[3] is None:
        index += 1
        display = load_and_display_sample(sample_queue, index)
    return index, display


def start_new_session():
    """Initialize a new user session.

    Returns:
        A 10-tuple ``(session_id, sample_queue, current_sample_index,
        message_markdown, bg_image, fg_image, sample_data, output_images,
        displayed_models_info, end_flag)``. ``end_flag`` is True when there is
        nothing to rate.
    """
    session_id = utils.generate_session_id()
    sample_queue = utils.prepare_session_samples(ALL_SAMPLES_BY_DOMAIN, config.SAMPLES_PER_DOMAIN)
    current_sample_index = 0

    if not sample_queue:
        no_samples_msg = "# 😥 No Samples Available!\n\n### Please check the data folder configuration or try again later."
        return session_id, sample_queue, current_sample_index, no_samples_msg, None, None, None, [], [], True

    print(f"New session started: {session_id}, with {len(sample_queue)} samples.")
    current_sample_index, display = _load_first_available_sample(sample_queue, current_sample_index)
    return (session_id, sample_queue, current_sample_index, *display)


def load_and_display_sample(sample_queue, current_sample_index):
    """Load and prepare a single sample for display.

    Args:
        sample_queue: Sequence of ``(domain, sample_id)`` pairs for the session.
        current_sample_index: Position in ``sample_queue`` to load.

    Returns:
        A 7-tuple ``(message_markdown, bg_image, fg_image, sample_data,
        output_images, displayed_models_info, end_flag)``.

        * Past the end of the queue: an "all rated" message, empty values and
          ``end_flag=True``.
        * Sample (or its input images) could not be loaded: an error message,
          ``sample_data=None`` and ``end_flag=False``.
        * Otherwise: exactly ``NUM_OUTPUT_SLOTS`` output images in shuffled
          order, padded with grey blanks whose info entry is
          ``("BLANK_SLOT", "N/A")``.
    """
    if not sample_queue or current_sample_index >= len(sample_queue):
        end_session_msg = "# 🎉 All Rated! 🎉\n\n### All samples for this session have been rated. Thank you!"
        return end_session_msg, None, None, None, [], [], True  # End of session

    domain, sample_id = sample_queue[current_sample_index]
    sample_data = utils.load_sample_data(domain, sample_id)
    if sample_data is None:
        print(f"Error loading sample {domain}/{sample_id}. Skipping.")
        return _sample_error_result(domain, sample_id)

    prompt_text = sample_data["prompt"]
    bg_img_path = sample_data["background_img_path"]
    fg_img_path = sample_data["foreground_img_path"]

    # Input bg/fg images are not forced to be square; the gr.Image component
    # is given a fixed height and is left to scale them.
    try:
        bg_image_to_display = _load_image(bg_img_path)
        fg_image_to_display = _load_image(fg_img_path)
    except (OSError, ValueError) as e:
        # Treat unreadable inputs like any other unloadable sample.
        print(f"Error loading input images for sample {domain}/{sample_id}: {e}. Skipping.")
        return _sample_error_result(domain, sample_id)

    # Shuffle so the position of an option does not reveal which model made it.
    output_model_keys = list(sample_data["output_image_paths"].keys())
    random.shuffle(output_model_keys)

    displayed_models_info = []
    output_images_for_display = []
    # Output option images are always shown square.
    square_size = (config.IMAGE_DISPLAY_SIZE[0], config.IMAGE_DISPLAY_SIZE[0])

    for model_key in output_model_keys:
        img_path = sample_data["output_image_paths"][model_key]
        try:
            img = Image.open(img_path).resize(square_size)
        except FileNotFoundError:
            print(f"Image not found: {img_path} for model {model_key}. Skipping this option.")
            continue
        except Exception as e:
            print(f"Error loading or resizing image {img_path}: {e}. Skipping this option.")
            continue
        output_images_for_display.append(img)
        displayed_models_info.append((model_key, img_path))

    # Pad with blanks so the UI always receives NUM_OUTPUT_SLOTS options.
    blank_image = Image.new('RGB', square_size, (200, 200, 200))
    while len(output_images_for_display) < NUM_OUTPUT_SLOTS:
        output_images_for_display.append(blank_image)
        displayed_models_info.append(BLANK_SLOT_INFO)

    domain_prompt_markdown = f"""
### Domain: {domain}

## Prompt ### _"{prompt_text}"_
"""

    return (
        domain_prompt_markdown,
        bg_image_to_display,  # Pass the PIL image directly
        fg_image_to_display,  # Pass the PIL image directly
        sample_data,
        output_images_for_display[:NUM_OUTPUT_SLOTS],
        displayed_models_info[:NUM_OUTPUT_SLOTS],
        False,
    )


def process_vote(choice_index, session_id, sample_queue, current_sample_index, current_sample_data, displayed_models_info_for_sample):
    """Record the user's choice for the current sample and load the next one.

    Args:
        choice_index: Index of the option button that was clicked.
        session_id: Identifier of the voting session.
        sample_queue: Sequence of ``(domain, sample_id)`` pairs for the session.
        current_sample_index: Queue position of the sample being voted on.
        current_sample_data: Sample dict for the current sample (or None).
        displayed_models_info_for_sample: ``(model_key, image_path)`` pairs in
            the order the options were displayed.

    Returns:
        A 9-tuple ``(preferences_df, current_sample_index, message_markdown,
        bg_image, fg_image, sample_data, output_images,
        displayed_models_info, end_flag)`` describing what to show next.

    Side effects:
        Rebinds the module-level ``preferences_df`` and appends the new row to
        ``config.RESULTS_CSV_FILE`` under a file lock.
    """
    global preferences_df

    invalid_vote = (
        current_sample_data is None
        or not displayed_models_info_for_sample
        or choice_index >= len(displayed_models_info_for_sample)
    )
    if invalid_vote:
        print("Error: Invalid data for processing vote. Skipping.")
        current_sample_index += 1
        if current_sample_index >= len(sample_queue):
            error_end_msg = "# ⚠️ Error Processing Vote ⚠️\n\n### An issue occurred. The session has ended."
            return preferences_df, current_sample_index, error_end_msg, None, None, None, [], [], True
        current_sample_index, display = _load_first_available_sample(sample_queue, current_sample_index)
        return (preferences_df, current_sample_index, *display)

    domain, sample_id = sample_queue[current_sample_index]
    preferred_model_key, _ = displayed_models_info_for_sample[choice_index]

    if preferred_model_key == "BLANK_SLOT":
        # Stay on the same sample; nothing is recorded for a padding slot.
        print("User clicked on a blank slot. Vote not recorded. Please select a valid image.")
        display = load_and_display_sample(sample_queue, current_sample_index)
        return (preferences_df, current_sample_index, *display)

    print(f"Session {session_id}: Voted for model '{config.MODEL_DISPLAY_NAMES.get(preferred_model_key, preferred_model_key)}' (key: {preferred_model_key}) for sample {domain}/{sample_id}")

    preferences_df = utils.record_preference(
        df=preferences_df,
        session_id=session_id,
        domain=domain,
        sample_id=sample_id,
        prompt=current_sample_data["prompt"],
        bg_path=current_sample_data["background_img_path"],
        fg_path=current_sample_data["foreground_img_path"],
        displayed_models_info=displayed_models_info_for_sample,
        preferred_model_key=preferred_model_key,
    )

    # Append only the new row to the CSV (the file is never rewritten in full
    # here; uploading is left to the scheduled job).
    if not preferences_df.empty:
        new_preference_df = preferences_df.iloc[-1:]  # Last row as a one-row DataFrame
        lock_path = config.RESULTS_CSV_FILE + ".lock"
        with FileLock(lock_path):
            print(f"Acquired lock for vote processing: {lock_path}")
            try:
                file_exists_and_has_content = (
                    os.path.exists(config.RESULTS_CSV_FILE)
                    and os.path.getsize(config.RESULTS_CSV_FILE) > 0
                )
                new_preference_df.to_csv(
                    config.RESULTS_CSV_FILE,
                    mode='a',
                    header=not file_exists_and_has_content,  # Write header if file is new or empty
                    index=False,
                )
                print(f"Appended new preference to {config.RESULTS_CSV_FILE}")
            except Exception as e:
                print(f"Error appending preference to local CSV {config.RESULTS_CSV_FILE}: {e}")
            finally:
                print(f"Released lock for vote processing: {lock_path}")
    else:
        print("Warning: preferences_df is empty after utils.record_preference. Cannot append to CSV.")

    current_sample_index += 1

    if current_sample_index >= len(sample_queue):
        final_msg = "# 🎉 Session Complete! 🎉\n\n### All samples have been rated. Thank you for your participation!"
        return preferences_df, current_sample_index, final_msg, None, None, None, [], [], True

    current_sample_index, display = _load_first_available_sample(sample_queue, current_sample_index)
    return (preferences_df, current_sample_index, *display)


# --- Gradio UI Definition ---

custom_css = """
.custom-vote-button {
    background-color: #FFA500 !important; /* Light Orange for normal state */
    border-color: #FFA500 !important; /* Light Orange for normal state */
    color: white !important;
}
.custom-vote-button:hover {
    background-color: #FF8C00 !important; /* Dark Orange for hover state */
    border-color: #FF8C00 !important; /* Dark Orange for hover state */
    color: white !important;
}
"""

with gr.Blocks(title=config.APP_TITLE, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue), css=custom_css) as demo:
    session_id_state = gr.State()
    sample_queue_state = gr.State([])
    current_sample_index_state = gr.State(0)
    current_sample_data_state = gr.State()
    displayed_models_info_state = gr.State([])
    preferences_df_state = gr.State(value=preferences_df)

    gr.Markdown(f"# {config.APP_TITLE}")
    gr.Markdown(config.APP_DESCRIPTION)

    with gr.Row():
        start_button = gr.Button("Start New Session / Load First Sample", variant="primary")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            domain_prompt_info_display = gr.Markdown(value="### Click 'Start New Session' to begin.")
        with gr.Column(scale=2):
            with gr.Row():
                input_bg_image_display = gr.Image(label="Input Background", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)
                input_fg_image_display = gr.Image(label="Input Foreground", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)

    gr.Markdown("---")
    gr.Markdown("## Choose your preferred composed image:")

    output_image_displays = []
    vote_buttons = []
    with gr.Row():
        for i in range(NUM_OUTPUT_SLOTS):
            with gr.Column():
                img_display = gr.Image(label=f"Option {i+1}", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)
                output_image_displays.append(img_display)
                vote_btn = gr.Button(f"Select Option {i+1}", elem_id=f"vote_btn_{i}", elem_classes=["custom-vote-button"])
                vote_buttons.append(vote_btn)

    end_of_session_msg_display = gr.Markdown("", visible=True)

    def build_display_updates(message, bg, fg, s_data, out_imgs, disp_info, end):
        """Build the component updates shared by the start and vote handlers.

        ``message`` goes to the end-of-session area when ``end`` is true and to
        the prompt area otherwise. A vote button is enabled only while the
        session is running and its slot holds a real (non-blank) option.
        """
        out_imgs = list(out_imgs) + [None] * (NUM_OUTPUT_SLOTS - len(out_imgs))
        disp_info = list(disp_info) + [BLANK_SLOT_INFO] * (NUM_OUTPUT_SLOTS - len(disp_info))

        # Real options always precede the blank padding, so a count is enough.
        num_actual_outputs = 0
        if s_data and "output_image_paths" in s_data and s_data["output_image_paths"]:
            num_actual_outputs = sum(
                1 for m_key, _ in disp_info if m_key != "BLANK_SLOT" and m_key is not None
            )

        updates = {
            domain_prompt_info_display: message if not end else "",
            input_bg_image_display: bg,
            input_fg_image_display: fg,
            current_sample_data_state: s_data,
            displayed_models_info_state: disp_info,
            end_of_session_msg_display: message if end else "",
        }
        for slot in range(NUM_OUTPUT_SLOTS):
            updates[output_image_displays[slot]] = out_imgs[slot]
            updates[vote_buttons[slot]] = gr.Button(interactive=not end and slot < num_actual_outputs)
        return updates

    def handle_start_session():
        """Start a session and return the updates for every bound component."""
        s_id, s_queue, s_idx, message, bg, fg, s_data, out_imgs, disp_info, end = start_new_session()
        updates = build_display_updates(message, bg, fg, s_data, out_imgs, disp_info, end)
        updates[session_id_state] = s_id
        updates[sample_queue_state] = s_queue
        updates[current_sample_index_state] = s_idx
        return updates

    start_button.click(
        fn=handle_start_session,
        inputs=[],
        outputs=[
            session_id_state, sample_queue_state, current_sample_index_state,
            domain_prompt_info_display,
            input_bg_image_display,
            input_fg_image_display,
            current_sample_data_state, displayed_models_info_state,
            end_of_session_msg_display,
            *output_image_displays,
            *vote_buttons,
        ],
    )

    def make_vote_fn(choice_idx):
        """Return the click handler for the vote button at ``choice_idx``."""

        def vote_action(s_id, s_queue, s_idx, current_s_data, disp_info_for_sample, prefs_df_val):
            global preferences_df
            # process_vote reads and rebinds the module-level DataFrame, so the
            # value held in this session's state is installed there first.
            preferences_df = prefs_df_val

            new_prefs_df, new_s_idx, message, bg, fg, new_s_data, out_imgs, new_disp_info, end = process_vote(
                choice_idx, s_id, s_queue, s_idx, current_s_data, disp_info_for_sample
            )
            updates = build_display_updates(message, bg, fg, new_s_data, out_imgs, new_disp_info, end)
            updates[preferences_df_state] = new_prefs_df
            updates[current_sample_index_state] = new_s_idx
            return updates

        return vote_action

    for i, btn in enumerate(vote_buttons):
        btn.click(
            fn=make_vote_fn(i),
            inputs=[
                session_id_state, sample_queue_state, current_sample_index_state,
                current_sample_data_state, displayed_models_info_state,
                preferences_df_state,
            ],
            outputs=[
                preferences_df_state,
                current_sample_index_state,
                domain_prompt_info_display,
                input_bg_image_display,
                input_fg_image_display,
                current_sample_data_state, displayed_models_info_state,
                end_of_session_msg_display,
                *output_image_displays,
                *vote_buttons,
            ],
        )

    gr.Markdown(config.FOOTER_MESSAGE)


def _create_dummy_data():
    """Populate ``config.DATA_FOLDER`` with solid-colour placeholder samples.

    For each dummy domain, ``config.SAMPLES_PER_DOMAIN + 2`` samples are
    written, each with a prompt file, a background image, a foreground image
    and one image per key in ``config.MODEL_OUTPUT_IMAGE_NAMES``. Image
    creation errors are logged and do not stop the remaining samples.
    """
    os.makedirs(config.DATA_FOLDER, exist_ok=True)
    dummy_domains = ["Real-Cartoon", "Real-Painting"]
    dummy_model_keys = list(config.MODEL_OUTPUT_IMAGE_NAMES.keys())
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]

    for domain in dummy_domains:
        domain_path = os.path.join(config.DATA_FOLDER, domain)
        os.makedirs(domain_path, exist_ok=True)
        for i in range(config.SAMPLES_PER_DOMAIN + 2):
            sample_id = f"sample_{i:03d}"
            sample_path = os.path.join(domain_path, sample_id)
            os.makedirs(sample_path, exist_ok=True)

            with open(os.path.join(sample_path, config.PROMPT_FILE_NAME), "w") as f:
                f.write(f"This is a dummy prompt for {domain} sample {sample_id}.")

            try:
                img_bg = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color='gray')
                img_bg.save(os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME))
                img_fg = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color='lightgray')
                img_fg.save(os.path.join(sample_path, config.FOREGROUND_IMAGE_NAME))

                for idx, model_key in enumerate(dummy_model_keys):
                    model_img_name = config.MODEL_OUTPUT_IMAGE_NAMES[model_key]
                    img_model = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color=colors[idx % len(colors)])
                    img_model.save(os.path.join(sample_path, model_img_name))
            except Exception as e:
                print(f"Error creating dummy image: {e}")


if __name__ == "__main__":
    if not os.path.exists(config.DATA_FOLDER):
        print(f"Creating dummy data folder: {config.DATA_FOLDER}")
        _create_dummy_data()
        print("Dummy data creation complete.")
        # The startup scan ran before the folder existed, so scan again.
        ALL_SAMPLES_BY_DOMAIN = utils.scan_data_directory(config.DATA_FOLDER)

    demo.launch()