import os
import random
import shutil
import uuid
from datetime import datetime

import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, login
from PIL import Image

import config

# --- Hugging Face Hub Functions ---

def login_hugging_face():
    """Logs in to Hugging Face Hub using the token from config or the environment."""
    token = config.HF_TOKEN or os.getenv("HF_HUB_TOKEN")
    if token:
        login(token=token)
        print("Successfully logged into Hugging Face Hub.")
    else:
        print("HF_TOKEN not set in config and HF_HUB_TOKEN not in environment. "
              "Proceeding without login; uploads to private repos will fail.")

def _init_empty_preferences_csv(local_file_path):
    """Creates or overwrites the CSV at `local_file_path` as empty with the
    expected headers, and returns the matching empty DataFrame."""
    df = pd.DataFrame(columns=config.CSV_HEADERS)
    df.to_csv(local_file_path, index=False)
    return df


def _load_preferences_csv(local_file_path, context="Local"):
    """Loads the CSV at `local_file_path`, re-initializing it as empty with the
    expected headers if it is empty, unreadable, or has unexpected columns.
    Always returns a DataFrame whose columns are config.CSV_HEADERS."""
    try:
        df = pd.read_csv(local_file_path)
        if list(df.columns) != config.CSV_HEADERS:
            print(f"Warning: {context} file {local_file_path} has incorrect headers. Re-initializing as empty with correct headers.")
            return _init_empty_preferences_csv(local_file_path)
        return df
    except pd.errors.EmptyDataError:
        print(f"{context} file {local_file_path} is empty. Initializing DataFrame with headers.")
        return _init_empty_preferences_csv(local_file_path)
    except Exception as e:
        print(f"Error reading {context} file {local_file_path}: {e}. Returning empty DataFrame with headers.")
        return _init_empty_preferences_csv(local_file_path)


def load_preferences_from_hf_hub(repo_id, filename):
    """Downloads the preferences CSV from the Hugging Face Hub dataset repo,
    overwriting the local file at `filename`, and returns it as a DataFrame.

    If the file does not exist on the Hub, or the download fails and no usable
    local copy exists, the local file is (re)created as an empty CSV with the
    expected headers and the corresponding empty DataFrame is returned.
    """
    local_file_path = filename  # the target local file path
    download_successful = False
    hub_file_exists = True
    try:
        print(f"Attempting to download {filename} from {repo_id} to {local_file_path}")
        # hf_hub_download places the file under `local_dir` and returns the
        # resulting path, which may differ from local_file_path (e.g. a cache
        # or snapshot path); if so, move it to the exact target location.
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,  # path inside the repo
            repo_type="dataset",
            local_dir=os.path.dirname(local_file_path) or ".",
            local_dir_use_symlinks=False,
        )
        target_dir = os.path.dirname(local_file_path)
        if target_dir and not os.path.exists(target_dir):
            os.makedirs(target_dir)
        if os.path.abspath(downloaded_path) != os.path.abspath(local_file_path):
            shutil.move(downloaded_path, local_file_path)
        print(f"Successfully downloaded {filename} from {repo_id} to {local_file_path}")
        download_successful = True
    except Exception as e:  # covers missing files, auth failures, network issues
        if "404" in str(e) or "does not exist" in str(e).lower():
            print(f"File {filename} not found on Hugging Face Hub repository {repo_id}.")
            hub_file_exists = False
            # Do not delete an existing local file here: app.py expects an empty
            # DataFrame with headers when the Hub has no file, so the local CSV
            # is re-initialized below instead.
        else:
            print(f"Could not download {filename} from {repo_id}. Error: {e}")
            # Fall back to the local file, or create an empty one, below.

    if download_successful:
        return _load_preferences_csv(local_file_path, context="Downloaded")
    if not hub_file_exists:
        # The Hub has no file: reset the local CSV to empty-with-headers.
        print(f"Hub file {filename} does not exist. Ensuring local file {local_file_path} is empty with correct headers.")
        return _init_empty_preferences_csv(local_file_path)
    # Download failed for some other reason; the Hub file may exist, so fall
    # back to the local file if one is present.
    print(f"Download of {filename} failed. Attempting to load from local file {local_file_path}.")
    if os.path.exists(local_file_path):
        return _load_preferences_csv(local_file_path, context="Local (fallback)")
    print(f"Download of {filename} failed, and local file {local_file_path} not found. Initializing empty DataFrame with headers.")
    return _init_empty_preferences_csv(local_file_path)
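
# Example (sketch): typical startup flow in app.py. The repo id and CSV name
# here are illustrative, not taken from config:
#
#   login_hugging_face()
#   prefs_df = load_preferences_from_hf_hub("user/preference-dataset", "preferences.csv")
#   # prefs_df always has config.CSV_HEADERS as columns (possibly zero rows).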

def save_preferences_to_hf_hub(df, repo_id, filename, commit_message="Update preferences"):
    """Uploads the local preferences CSV to the Hugging Face Hub.

    The local CSV at `filename` is the source of truth for uploads: it is
    appended to by process_vote and read periodically by the scheduler, so this
    function deliberately does not rewrite it from `df`. The `df` argument only
    serves as a guard reflecting the state when the scheduler decided to run.
    """
    if df is None or df.empty:
        print("Preferences DataFrame (passed for checking) is empty.")
        # The file on disk is authoritative; upload it anyway if it has content.
        if not (os.path.exists(filename) and os.path.getsize(filename) > 0):
            print(f"Local file {filename} is also non-existent or empty. Nothing to upload.")
            return
        print(f"Passed DataFrame was empty, but local file {filename} exists and has content. Proceeding with upload of the file.")
    try:
        print(f"Attempting to upload existing file: {filename} to {repo_id}")
        api = HfApi()
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        print(f"Successfully uploaded {filename} to {repo_id}")
    except Exception as e:
        print(f"Error uploading {filename} to Hugging Face Hub: {e}")
        print("Changes are saved locally. Will attempt upload on next scheduled push.")

# --- Data Loading and Sampling ---

def scan_data_directory(data_folder):
    """
    Scans the data directory to find domains and their samples.
    Returns a dictionary: {"domain_name": ["sample_id1", "sample_id2", ...]}
    """
    all_samples_by_domain = {}
    if not os.path.isdir(data_folder):
        print(f"Error: Data folder '{data_folder}' not found.")
        return all_samples_by_domain
    for domain_name in os.listdir(data_folder):
        domain_path = os.path.join(data_folder, domain_name)
        if os.path.isdir(domain_path):
            all_samples_by_domain[domain_name] = []
            for sample_id in os.listdir(domain_path):
                sample_path = os.path.join(domain_path, sample_id)
                # Basic check: a sample is a directory containing at least the
                # prompt file and the background image.
                prompt_file = os.path.join(sample_path, config.PROMPT_FILE_NAME)
                bg_image = os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME)
                if os.path.isdir(sample_path) and os.path.exists(prompt_file) and os.path.exists(bg_image):
                    all_samples_by_domain[domain_name].append(sample_id)
            if not all_samples_by_domain[domain_name]:
                print(f"Warning: No valid samples found in domain '{domain_name}'.")
    if not all_samples_by_domain:
        print(f"Warning: No domains found or no valid samples in any domain in '{data_folder}'.")
    return all_samples_by_domain
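
# Expected layout (sketch; concrete file names come from config, the ones shown
# here are illustrative):
#
#   data/                      <- config.DATA_FOLDER
#     product/                 <- domain
#       sample_001/            <- sample_id
#         prompt.txt           <- config.PROMPT_FILE_NAME
#         background.png       <- config.BACKGROUND_IMAGE_NAME
#         ...
#
#   all_samples = scan_data_directory(config.DATA_FOLDER)
#   # -> {"product": ["sample_001", ...], ...}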

def prepare_session_samples(all_samples_by_domain, samples_per_domain):
    """
    Prepares a shuffled list of (domain, sample_id) tuples for a user session,
    randomly selecting up to `samples_per_domain` samples from each domain.
    """
    session_queue = []
    for domain, samples in all_samples_by_domain.items():
        if samples:  # only if the domain has samples
            chosen_samples = random.sample(samples, min(len(samples), samples_per_domain))
            for sample_id in chosen_samples:
                session_queue.append((domain, sample_id))
    random.shuffle(session_queue)
    return session_queue
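
# Example (sketch):
#
#   queue = prepare_session_samples(all_samples, samples_per_domain=3)
#   # -> e.g. [("product", "sample_007"), ("art", "sample_002"), ...] in random order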

# --- Session and Data Handling ---

def generate_session_id():
    """Generates a unique session ID."""
    return uuid.uuid4().hex[:config.SESSION_ID_LENGTH]

def load_sample_data(domain, sample_id):
    """
    Loads data for a specific sample: prompt, input images, and output image paths.
    Returns a dictionary or None if the data is incomplete.
    """
    sample_path = os.path.join(config.DATA_FOLDER, domain, sample_id)
    prompt_path = os.path.join(sample_path, config.PROMPT_FILE_NAME)
    bg_image_path = os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME)
    fg_image_path = os.path.join(sample_path, config.FOREGROUND_IMAGE_NAME)
    if not all(os.path.exists(p) for p in [prompt_path, bg_image_path, fg_image_path]):
        print(f"Error: Missing core files for sample {domain}/{sample_id}")
        return None
    try:
        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompt_text = f.read().strip()
    except Exception as e:
        print(f"Error reading prompt for {domain}/{sample_id}: {e}")
        return None
    output_images = {}  # {model_key: path_to_image}
    for model_key, img_name in config.MODEL_OUTPUT_IMAGE_NAMES.items():
        img_path = os.path.join(sample_path, img_name)
        if os.path.exists(img_path):
            output_images[model_key] = img_path
        else:
            print(f"Warning: Missing output image {img_name} for model {model_key} in sample {domain}/{sample_id}")
            # A sample with a missing output still proceeds and simply shows
            # fewer options; ideally all outputs are verified during data prep.
    if len(output_images) < len(config.MODEL_OUTPUT_IMAGE_NAMES):
        print(f"Warning: Sample {domain}/{sample_id} is missing one or more model outputs. It will have fewer than {len(config.MODEL_OUTPUT_IMAGE_NAMES)} options.")
    if not output_images:  # no outputs at all
        return None
    return {
        "prompt": prompt_text,
        "background_img_path": bg_image_path,
        "foreground_img_path": fg_image_path,
        "output_image_paths": output_images,  # dict {model_key: path}
    }
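
# Example (sketch; the domain and sample id are illustrative):
#
#   sample = load_sample_data("product", "sample_001")
#   if sample is not None:
#       print(sample["prompt"], list(sample["output_image_paths"].keys()))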

def record_preference(df, session_id, domain, sample_id, prompt, bg_path, fg_path, displayed_models_info, preferred_model_key):
    """
    Appends a new preference record to the DataFrame and returns the result.
    displayed_models_info: list of (model_key, image_path) in display order.
    preferred_model_key: the key of the model the user selected (e.g., "model_a").
    """
    timestamp = datetime.now().isoformat()
    new_row = {
        "session_id": session_id,
        "timestamp": timestamp,
        "domain": domain,
        "sample_id": sample_id,
        "prompt": prompt,
        "input_background": os.path.basename(bg_path),  # store just the filename for brevity
        "input_foreground": os.path.basename(fg_path),  # store just the filename for brevity
        "preferred_model_key": preferred_model_key,
        "preferred_model_filename": config.MODEL_OUTPUT_IMAGE_NAMES.get(preferred_model_key, "N/A"),
    }
    # Fill the display-order columns; the CSV schema has four slots, so pad
    # with None when fewer than four models were shown.
    for i in range(4):
        col_name = f"displayed_order_model_{i+1}"
        if i < len(displayed_models_info):
            new_row[col_name] = displayed_models_info[i][0]  # store the model_key
        else:
            new_row[col_name] = None
    new_df_row = pd.DataFrame([new_row], columns=config.CSV_HEADERS)
    if df is None:
        df = new_df_row
    else:
        df = pd.concat([df, new_df_row], ignore_index=True)
    return df
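
# Example (sketch): recording a vote. The model keys and paths are illustrative:
#
#   displayed = [("model_b", "data/product/sample_001/out_b.png"),
#                ("model_a", "data/product/sample_001/out_a.png")]
#   prefs_df = record_preference(prefs_df, session_id, "product", "sample_001",
#                                sample["prompt"], sample["background_img_path"],
#                                sample["foreground_img_path"], displayed, "model_b")
#   # The scheduler later uploads the CSV via save_preferences_to_hf_hub.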