Spaces:
Running
Running
File size: 9,252 Bytes
af5e0d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
\
import os
import random
import uuid
import pandas as pd
from datetime import datetime
from huggingface_hub import HfApi, hf_hub_download, login
from PIL import Image
import shutil
import config
# --- Hugging Face Hub Functions ---
def login_hugging_face():
    """Authenticate against the Hugging Face Hub.

    The token is taken from config.HF_TOKEN first, falling back to the
    HF_HUB_TOKEN environment variable. Without a token the app keeps
    running, but pushes to private repos will be rejected.
    """
    token = config.HF_TOKEN or os.getenv("HF_HUB_TOKEN")
    if not token:
        # Best-effort: anonymous mode still works for public repos.
        print("HF_TOKEN not set in config and HF_HUB_TOKEN not in environment. Proceeding without login. Uploads to private repos will fail.")
        return
    login(token=token)
    print("Successfully logged into Hugging Face Hub.")
def load_preferences_from_hf_hub(repo_id, filename):
    """Downloads the preferences CSV from the Hugging Face Hub dataset repo.

    Args:
        repo_id: Hub dataset repo id, e.g. "user/dataset".
        filename: Name of the CSV file inside the repo (and on local disk).

    Returns:
        A pandas DataFrame with the preferences, or None if the file cannot
        be downloaded and no local copy exists either.
    """
    try:
        downloaded_file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            local_dir=".",  # Download to current directory, ensure .gitignore if needed
            local_dir_use_symlinks=False
        )
        # hf_hub_download may place the file in a repo-structured subfolder;
        # move it so it sits directly in the current directory.
        if os.path.dirname(downloaded_file_path) != os.path.abspath("."):
            destination_path = os.path.join(".", os.path.basename(downloaded_file_path))
            shutil.move(downloaded_file_path, destination_path)
            downloaded_file_path = destination_path
        if os.path.exists(downloaded_file_path):
            print(f"Successfully downloaded {filename} from {repo_id}")
            df = pd.read_csv(downloaded_file_path)
            # Ensure the local copy carries exactly the name callers expect.
            if downloaded_file_path != filename:
                os.rename(downloaded_file_path, filename)
            return df
        else:  # Should not happen if download was successful
            print(f"Downloaded file {downloaded_file_path} does not exist locally.")
            return None
    except Exception as e:
        print(f"Could not download {filename} from {repo_id}. Error: {e}")
        print("Starting with an empty preferences table or local copy if available.")
        # Prefer a stale local copy over losing previously collected data.
        if os.path.exists(filename):
            print(f"Loading local copy of {filename}")
            return pd.read_csv(filename)
        return None
def save_preferences_to_hf_hub(df, repo_id, filename, commit_message="Update preferences"):
    """Saves the DataFrame to a local CSV and uploads it to the Hugging Face Hub.

    Args:
        df: Preferences DataFrame; None or an empty frame is a no-op.
        repo_id: Target Hub dataset repo id.
        filename: CSV name used both locally and as the path inside the repo.
        commit_message: Commit message for the Hub upload.

    Upload failures are non-fatal: the CSV is written locally first, so the
    data survives until the next successful scheduled push.
    """
    if df is None or df.empty:
        print("Preferences DataFrame is empty. Nothing to save or upload.")
        return
    try:
        # Local save happens first so data is never lost to a network error.
        df.to_csv(filename, index=False)
        print(f"Preferences saved locally to {filename}")
        api = HfApi()
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        print(f"Successfully uploaded {filename} to {repo_id}")
    except Exception as e:
        print(f"Error saving or uploading {filename} to Hugging Face Hub: {e}")
        print("Changes are saved locally. Will attempt upload on next scheduled push.")
# --- Data Loading and Sampling ---
def scan_data_directory(data_folder):
    """
    Scans the data directory to find domains and their samples.
    Returns a dictionary: {"domain_name": ["sample_id1", "sample_id2", ...]}
    """
    if not os.path.isdir(data_folder):
        print(f"Error: Data folder '{data_folder}' not found.")
        return {}
    discovered = {}
    for entry in os.listdir(data_folder):
        domain_dir = os.path.join(data_folder, entry)
        if not os.path.isdir(domain_dir):
            continue
        valid_ids = []
        for candidate in os.listdir(domain_dir):
            candidate_dir = os.path.join(domain_dir, candidate)
            # A valid sample is a directory containing at least the prompt
            # file and the background image.
            prompt_ok = os.path.exists(os.path.join(candidate_dir, config.PROMPT_FILE_NAME))
            bg_ok = os.path.exists(os.path.join(candidate_dir, config.BACKGROUND_IMAGE_NAME))
            if os.path.isdir(candidate_dir) and prompt_ok and bg_ok:
                valid_ids.append(candidate)
        discovered[entry] = valid_ids
        if not valid_ids:
            print(f"Warning: No valid samples found in domain '{entry}'.")
    if not discovered:
        print(f"Warning: No domains found or no valid samples in any domain in '{data_folder}'.")
    return discovered
def prepare_session_samples(all_samples_by_domain, samples_per_domain):
    """
    Build a shuffled session queue of (domain, sample_id) tuples, drawing at
    most `samples_per_domain` random samples from each non-empty domain.
    """
    session_queue = []
    for domain, candidates in all_samples_by_domain.items():
        if not candidates:
            continue  # domains without samples contribute nothing
        draw_count = min(len(candidates), samples_per_domain)
        session_queue.extend(
            (domain, sample_id) for sample_id in random.sample(candidates, draw_count)
        )
    random.shuffle(session_queue)
    return session_queue
# --- Session and Data Handling ---
def generate_session_id():
    """Generates a unique session ID."""
    full_hex = uuid.uuid4().hex
    # Truncate to the configured length; uuid4 gives ample entropy even short.
    return full_hex[:config.SESSION_ID_LENGTH]
def load_sample_data(domain, sample_id):
    """
    Loads data for one sample: prompt text, input image paths and per-model
    output image paths.

    Returns a dict with keys "prompt", "background_img_path",
    "foreground_img_path" and "output_image_paths", or None when the core
    files are missing, the prompt is unreadable, or no model output exists.
    """
    base_dir = os.path.join(config.DATA_FOLDER, domain, sample_id)
    prompt_file = os.path.join(base_dir, config.PROMPT_FILE_NAME)
    background_file = os.path.join(base_dir, config.BACKGROUND_IMAGE_NAME)
    foreground_file = os.path.join(base_dir, config.FOREGROUND_IMAGE_NAME)

    core_files = (prompt_file, background_file, foreground_file)
    if not all(os.path.exists(path) for path in core_files):
        print(f"Error: Missing core files for sample {domain}/{sample_id}")
        return None

    try:
        with open(prompt_file, 'r', encoding='utf-8') as fh:
            prompt_text = fh.read().strip()
    except Exception as e:
        print(f"Error reading prompt for {domain}/{sample_id}: {e}")
        return None

    # Collect whichever model outputs are actually present on disk.
    output_images = {}  # {model_key: path_to_image}
    for model_key, img_name in config.MODEL_OUTPUT_IMAGE_NAMES.items():
        candidate = os.path.join(base_dir, img_name)
        if not os.path.exists(candidate):
            print(f"Warning: Missing output image {img_name} for model {model_key} in sample {domain}/{sample_id}")
            continue
        output_images[model_key] = candidate

    # A sample with some (but not all) outputs is still usable; it simply
    # presents fewer options. Samples with no outputs at all are rejected.
    if len(output_images) < len(config.MODEL_OUTPUT_IMAGE_NAMES):
        print(f"Warning: Sample {domain}/{sample_id} is missing one or more model outputs. It will have fewer than 4 options.")
    if not output_images:
        return None

    return {
        "prompt": prompt_text,
        "background_img_path": background_file,
        "foreground_img_path": foreground_file,
        "output_image_paths": output_images,  # dict {model_key: path}
    }
def record_preference(df, session_id, domain, sample_id, prompt, bg_path, fg_path, displayed_models_info, preferred_model_key):
    """
    Appends a new preference record to the DataFrame.
    displayed_models_info: list of (model_key, image_path) in the order they were displayed.
    preferred_model_key: The key of the model the user selected (e.g., "model_a").
    """
    record = {
        "session_id": session_id,
        "timestamp": datetime.now().isoformat(),
        "domain": domain,
        "sample_id": sample_id,
        "prompt": prompt,
        # Only the basenames are stored, to keep the CSV compact.
        "input_background": os.path.basename(bg_path),
        "input_foreground": os.path.basename(fg_path),
        "preferred_model_key": preferred_model_key,
        "preferred_model_filename": config.MODEL_OUTPUT_IMAGE_NAMES.get(preferred_model_key, "N/A"),
    }
    # Always emit all four display-slot columns; slots beyond the number of
    # models actually shown hold None so the CSV schema stays stable.
    for slot in range(4):
        shown_key = displayed_models_info[slot][0] if slot < len(displayed_models_info) else None
        record[f"displayed_order_model_{slot+1}"] = shown_key
    fresh_row = pd.DataFrame([record], columns=config.CSV_HEADERS)
    return fresh_row if df is None else pd.concat([df, fresh_row], ignore_index=True)
|