Spaces:
Running
Running
\ | |
import os | |
import random | |
import uuid | |
import pandas as pd | |
from datetime import datetime | |
from huggingface_hub import HfApi, hf_hub_download, login | |
from PIL import Image | |
import shutil | |
import config | |
# --- Hugging Face Hub Functions --- | |
def login_hugging_face():
    """Authenticate with the Hugging Face Hub when a token is available.

    The token comes from ``config.HF_TOKEN`` or, if that is falsy, from the
    ``HF_HUB_TOKEN`` environment variable. With no token, a notice is printed
    and no login is attempted.
    """
    hub_token = config.HF_TOKEN or os.getenv("HF_HUB_TOKEN")
    if not hub_token:
        print("HF_TOKEN not set in config and HF_HUB_TOKEN not in environment. Proceeding without login. Uploads to private repos will fail.")
        return
    login(token=hub_token)
    print("Successfully logged into Hugging Face Hub.")
def load_preferences_from_hf_hub(repo_id, filename):
    """Download the preferences CSV from a Hugging Face Hub dataset repo.

    The file is fetched into the current working directory, moved to the
    local path ``filename`` if the Hub client placed it elsewhere, and parsed
    with pandas.

    Args:
        repo_id: Id of the dataset repository on the Hub.
        filename: Path of the CSV inside the repo; also used as the local path.

    Returns:
        The downloaded preferences as a DataFrame. If downloading, moving or
        parsing fails, the existing local copy at ``filename`` is loaded
        instead when it exists and is readable; otherwise ``None``.
    """
    try:
        downloaded_file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            local_dir=".",  # Download to current directory, ensure .gitignore if needed
            # Ask for a real file rather than a symlink into the cache so it can be moved.
            # NOTE(review): newer huggingface_hub releases deprecate/ignore this flag.
            local_dir_use_symlinks=False,
        )
        target_path = os.path.abspath(filename)
        # Compare normalized absolute paths: the client may return a relative or
        # differently spelled path for the same file, and moving a file onto
        # itself can fail on some platforms.
        if os.path.abspath(downloaded_file_path) != target_path:
            os.makedirs(os.path.dirname(target_path), exist_ok=True)
            shutil.move(downloaded_file_path, target_path)
        if not os.path.exists(target_path):  # Should not happen if download was successful
            print(f"Downloaded file {downloaded_file_path} does not exist locally.")
            return None
        print(f"Successfully downloaded {filename} from {repo_id}")
        return pd.read_csv(target_path)
    except Exception as e:
        # Best-effort: any failure falls back to whatever local copy exists.
        print(f"Could not download {filename} from {repo_id}. Error: {e}")
        print("Starting with an empty preferences table or local copy if available.")
        return _read_local_preferences_copy(filename)


def _read_local_preferences_copy(filename):
    """Return the local CSV at ``filename`` as a DataFrame, or None if it is absent or unreadable."""
    if not os.path.exists(filename):
        return None
    print(f"Loading local copy of {filename}")
    try:
        return pd.read_csv(filename)
    except (OSError, ValueError) as e:
        # pandas EmptyDataError/ParserError and UnicodeDecodeError all derive from ValueError.
        print(f"Could not read local copy of {filename}. Error: {e}")
        return None
def save_preferences_to_hf_hub(df, repo_id, filename, commit_message="Update preferences"):
    """Write preferences to a local CSV, then upload that file to the Hub.

    Both steps are best-effort: failures are reported with ``print`` and never
    raised.

    Args:
        df: Preferences DataFrame; ``None`` or an empty frame is skipped.
        repo_id: Id of the dataset repository on the Hub.
        filename: Local CSV path, also used as the file's path inside the repo.
        commit_message: Commit message for the Hub upload.
    """
    if df is None or df.empty:
        print("Preferences DataFrame is empty. Nothing to save or upload.")
        return
    # Save locally first and stop if that fails: uploading a stale or missing
    # file is pointless, and "saved locally" must only be claimed when true.
    try:
        df.to_csv(filename, index=False)
    except Exception as e:
        print(f"Error saving {filename} locally: {e}")
        return
    print(f"Preferences saved locally to {filename}")
    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
    except Exception as e:
        print(f"Error uploading {filename} to Hugging Face Hub: {e}")
        print("Changes are saved locally. Will attempt upload on next scheduled push.")
        return
    print(f"Successfully uploaded {filename} to {repo_id}")
# --- Data Loading and Sampling --- | |
def scan_data_directory(data_folder):
    """Scan ``data_folder`` for domains and their valid samples.

    Each immediate subdirectory of ``data_folder`` is a domain. Each
    subdirectory of a domain that contains both ``config.PROMPT_FILE_NAME``
    and ``config.BACKGROUND_IMAGE_NAME`` is a valid sample.

    Returns:
        dict mapping domain name -> list of valid sample ids, e.g.
        ``{"domain_name": ["sample_id1", "sample_id2", ...]}``. Domains with
        no valid samples are kept with an empty list. An empty dict is
        returned if ``data_folder`` is not a directory.
    """
    all_samples_by_domain = {}
    if not os.path.isdir(data_folder):
        print(f"Error: Data folder '{data_folder}' not found.")
        return all_samples_by_domain
    # os.listdir order is arbitrary; sorting gives a deterministic result order.
    for domain_name in sorted(os.listdir(data_folder)):
        domain_path = os.path.join(data_folder, domain_name)
        if not os.path.isdir(domain_path):
            continue
        valid_samples = []
        for sample_id in sorted(os.listdir(domain_path)):
            sample_path = os.path.join(domain_path, sample_id)
            # Basic check: a directory holding the prompt and background image.
            prompt_file = os.path.join(sample_path, config.PROMPT_FILE_NAME)
            bg_image = os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME)
            if os.path.isdir(sample_path) and os.path.exists(prompt_file) and os.path.exists(bg_image):
                valid_samples.append(sample_id)
        all_samples_by_domain[domain_name] = valid_samples
        if not valid_samples:
            print(f"Warning: No valid samples found in domain '{domain_name}'.")
    # Warn when nothing is usable at all: no domains, or every domain is empty.
    if not any(all_samples_by_domain.values()):
        print(f"Warning: No domains found or no valid samples in any domain in '{data_folder}'.")
    return all_samples_by_domain
def prepare_session_samples(all_samples_by_domain, samples_per_domain):
    """Build a shuffled queue of ``(domain, sample_id)`` pairs for one session.

    For every domain with at least one sample, up to ``samples_per_domain``
    distinct sample ids are drawn at random (all of them if the domain holds
    fewer). The combined list is shuffled in place and returned.
    """
    queue = [
        (domain, picked)
        for domain, candidates in all_samples_by_domain.items()
        if candidates  # domains without samples contribute nothing
        for picked in random.sample(candidates, min(len(candidates), samples_per_domain))
    ]
    random.shuffle(queue)
    return queue
# --- Session and Data Handling --- | |
def generate_session_id():
    """Return a new random session id.

    The id is the hex form of a fresh ``uuid.uuid4()`` (32 characters),
    truncated to its first ``config.SESSION_ID_LENGTH`` characters.
    """
    full_hex = uuid.uuid4().hex
    return full_hex[:config.SESSION_ID_LENGTH]
def load_sample_data(domain, sample_id):
    """Load everything needed to display one sample.

    Reads the prompt text and resolves the background, foreground and model
    output image paths under ``config.DATA_FOLDER/<domain>/<sample_id>``.

    Returns:
        dict with keys ``"prompt"``, ``"background_img_path"``,
        ``"foreground_img_path"`` and ``"output_image_paths"``. The last is a
        dict ``{model_key: image_path}`` holding only outputs that exist.
        Returns ``None`` in any of these cases:
        - the prompt, background or foreground file is missing;
        - the prompt cannot be read;
        - no model output image exists.
    """
    sample_path = os.path.join(config.DATA_FOLDER, domain, sample_id)
    prompt_path = os.path.join(sample_path, config.PROMPT_FILE_NAME)
    bg_image_path = os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME)
    fg_image_path = os.path.join(sample_path, config.FOREGROUND_IMAGE_NAME)
    if not all(os.path.exists(p) for p in (prompt_path, bg_image_path, fg_image_path)):
        print(f"Error: Missing core files for sample {domain}/{sample_id}")
        return None
    try:
        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompt_text = f.read().strip()
    except (OSError, UnicodeError) as e:
        # Only I/O and decoding failures are expected when reading a text file.
        print(f"Error reading prompt for {domain}/{sample_id}: {e}")
        return None

    output_images = {}  # {model_key: path_to_image}
    for model_key, img_name in config.MODEL_OUTPUT_IMAGE_NAMES.items():
        img_path = os.path.join(sample_path, img_name)
        if os.path.exists(img_path):
            output_images[model_key] = img_path
        else:
            print(f"Warning: Missing output image {img_name} for model {model_key} in sample {domain}/{sample_id}")

    if not output_images:  # No outputs at all: nothing to compare.
        print(f"Error: No model output images found for sample {domain}/{sample_id}")
        return None
    # A sample with some outputs missing is still usable; it just shows fewer options.
    expected_outputs = len(config.MODEL_OUTPUT_IMAGE_NAMES)
    if len(output_images) < expected_outputs:
        print(f"Warning: Sample {domain}/{sample_id} is missing one or more model outputs. It will have fewer than {expected_outputs} options.")
    return {
        "prompt": prompt_text,
        "background_img_path": bg_image_path,
        "foreground_img_path": fg_image_path,
        "output_image_paths": output_images,  # dict {model_key: path}
    }
def record_preference(df, session_id, domain, sample_id, prompt, bg_path, fg_path, displayed_models_info, preferred_model_key):
    """Append a single preference record to ``df`` and return the result.

    Args:
        df: Existing preferences DataFrame, or ``None`` to start a new one.
        session_id: Id of the current user session.
        domain: Domain the sample belongs to.
        sample_id: Id of the sample within its domain.
        prompt: Prompt text shown to the user.
        bg_path: Background input image path. Only its basename is stored.
        fg_path: Foreground input image path. Only its basename is stored.
        displayed_models_info: Sequence of ``(model_key, image_path)`` pairs
            in display order. Only the model keys are recorded.
        preferred_model_key: Key of the model the user selected (e.g. "model_a").

    Returns:
        The one-row DataFrame when ``df`` is None. Otherwise ``df``
        concatenated with that row, re-indexed from 0.
    """
    # Always fill four display-order slots so every record has the same keys;
    # slots beyond the number of models shown are None.
    display_order = {
        f"displayed_order_model_{slot + 1}": (
            displayed_models_info[slot][0] if slot < len(displayed_models_info) else None
        )
        for slot in range(4)
    }
    record = {
        "session_id": session_id,
        "timestamp": datetime.now().isoformat(),
        "domain": domain,
        "sample_id": sample_id,
        "prompt": prompt,
        "input_background": os.path.basename(bg_path),  # file name only, not the full path
        "input_foreground": os.path.basename(fg_path),  # file name only, not the full path
        "preferred_model_key": preferred_model_key,
        "preferred_model_filename": config.MODEL_OUTPUT_IMAGE_NAMES.get(preferred_model_key, "N/A"),
        **display_order,
    }
    # columns= makes pandas lay the row out according to config.CSV_HEADERS.
    row = pd.DataFrame([record], columns=config.CSV_HEADERS)
    if df is None:
        return row
    return pd.concat([df, row], ignore_index=True)