# NOTE: The lines below are Hugging Face file-viewer chrome captured along with
# the source (author "matsant01", commit af5e0d4 "Major update of code. Adding
# new data with our generations", raw/history/blame links, 9.25 kB). Kept as a
# comment so the module remains valid Python.
import os
import random
import uuid
import pandas as pd
from datetime import datetime
from huggingface_hub import HfApi, hf_hub_download, login
from PIL import Image
import shutil
import config
# --- Hugging Face Hub Functions ---
def login_hugging_face():
    """Log in to the Hugging Face Hub.

    The token is read from ``config.HF_TOKEN`` first, then from the
    ``HF_HUB_TOKEN`` environment variable. Without a token the app keeps
    running un-authenticated, but uploads to private repos will fail.
    """
    token = config.HF_TOKEN or os.getenv("HF_HUB_TOKEN")
    if not token:
        # Best-effort: warn and continue rather than abort the app.
        print("HF_TOKEN not set in config and HF_HUB_TOKEN not in environment. Proceeding without login. Uploads to private repos will fail.")
        return
    login(token=token)
    print("Successfully logged into Hugging Face Hub.")
def load_preferences_from_hf_hub(repo_id, filename):
    """Downloads the preferences CSV from the Hugging Face Hub dataset repo.

    Args:
        repo_id: Hub dataset repository id (e.g. "user/dataset").
        filename: CSV file name inside the repo; the local copy ends up
            under this exact name in the current directory.

    Returns a Pandas DataFrame, a DataFrame loaded from a pre-existing
    local copy if the download fails, or None if neither is available.
    """
    try:
        downloaded_file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            local_dir=".",  # Download to current directory, ensure .gitignore if needed
            local_dir_use_symlinks=False
        )
        # hf_hub_download may place the file in a subfolder. Normalize both
        # sides before comparing: the original compared a (possibly relative)
        # dirname against an absolute path, so the test could misfire.
        if os.path.abspath(os.path.dirname(downloaded_file_path)) != os.path.abspath("."):
            destination_path = os.path.join(".", os.path.basename(downloaded_file_path))
            shutil.move(downloaded_file_path, destination_path)
            downloaded_file_path = destination_path
        if os.path.exists(downloaded_file_path):
            print(f"Successfully downloaded {filename} from {repo_id}")
            df = pd.read_csv(downloaded_file_path)
            # Ensure local copy is named as expected by config
            if downloaded_file_path != filename:
                os.rename(downloaded_file_path, filename)
            return df
        else:  # Should not happen if download was successful
            print(f"Downloaded file {downloaded_file_path} does not exist locally.")
            return None
    except Exception as e:
        print(f"Could not download {filename} from {repo_id}. Error: {e}")
        print("Starting with an empty preferences table or local copy if available.")
        # Fall back to a previously downloaded local copy, if one exists.
        if os.path.exists(filename):
            print(f"Loading local copy of {filename}")
            return pd.read_csv(filename)
        return None
def save_preferences_to_hf_hub(df, repo_id, filename, commit_message="Update preferences"):
    """Saves the DataFrame to a local CSV and uploads it to the Hugging Face Hub.

    Args:
        df: Preferences DataFrame; None or empty is a no-op.
        repo_id: Target Hub dataset repository id.
        filename: Local CSV path, also used as the path inside the repo.
        commit_message: Commit message for the Hub upload.

    Upload failures are swallowed (best-effort): the CSV is already saved
    locally and a later scheduled push can retry.
    """
    if df is None or df.empty:
        print("Preferences DataFrame is empty. Nothing to save or upload.")
        return
    try:
        df.to_csv(filename, index=False)
        print(f"Preferences saved locally to {filename}")
        api = HfApi()
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        print(f"Successfully uploaded {filename} to {repo_id}")
    except Exception as e:
        print(f"Error saving or uploading {filename} to Hugging Face Hub: {e}")
        print("Changes are saved locally. Will attempt upload on next scheduled push.")
# --- Data Loading and Sampling ---
def scan_data_directory(data_folder):
    """Discover domains and their valid sample ids under *data_folder*.

    A sample directory counts as valid when it contains both the prompt
    file and the background image named in config.

    Returns:
        dict mapping domain name -> list of sample ids (possibly empty).
    """
    all_samples_by_domain = {}
    if not os.path.isdir(data_folder):
        print(f"Error: Data folder '{data_folder}' not found.")
        return all_samples_by_domain
    for domain_name in os.listdir(data_folder):
        domain_dir = os.path.join(data_folder, domain_name)
        if not os.path.isdir(domain_dir):
            continue  # skip stray files at the top level
        valid_ids = []
        for candidate in os.listdir(domain_dir):
            candidate_dir = os.path.join(domain_dir, candidate)
            # Required files for a sample to be usable.
            required = (
                os.path.join(candidate_dir, config.PROMPT_FILE_NAME),
                os.path.join(candidate_dir, config.BACKGROUND_IMAGE_NAME),
            )
            if os.path.isdir(candidate_dir) and all(os.path.exists(p) for p in required):
                valid_ids.append(candidate)
        all_samples_by_domain[domain_name] = valid_ids
        if not valid_ids:
            print(f"Warning: No valid samples found in domain '{domain_name}'.")
    if not all_samples_by_domain:
        print(f"Warning: No domains found or no valid samples in any domain in '{data_folder}'.")
    return all_samples_by_domain
def prepare_session_samples(all_samples_by_domain, samples_per_domain):
    """Build a shuffled per-session queue of (domain, sample_id) pairs.

    Up to *samples_per_domain* sample ids are drawn at random, without
    replacement, from every non-empty domain; the combined queue is then
    shuffled so domains are interleaved.
    """
    session_queue = [
        (domain, picked_id)
        for domain, available in all_samples_by_domain.items()
        if available  # skip domains with no samples
        for picked_id in random.sample(available, min(len(available), samples_per_domain))
    ]
    random.shuffle(session_queue)
    return session_queue
# --- Session and Data Handling ---
def generate_session_id():
    """Generates a unique session ID (random hex truncated to the configured length)."""
    full_hex = uuid.uuid4().hex
    return full_hex[:config.SESSION_ID_LENGTH]
def load_sample_data(domain, sample_id):
    """Load one sample's prompt, input image paths, and model output paths.

    Returns a dict with keys "prompt", "background_img_path",
    "foreground_img_path" and "output_image_paths" (model_key -> path),
    or None when a core file is missing, the prompt cannot be read, or
    no model output image exists at all.
    """
    base_dir = os.path.join(config.DATA_FOLDER, domain, sample_id)
    prompt_file = os.path.join(base_dir, config.PROMPT_FILE_NAME)
    background_file = os.path.join(base_dir, config.BACKGROUND_IMAGE_NAME)
    foreground_file = os.path.join(base_dir, config.FOREGROUND_IMAGE_NAME)

    core_files = (prompt_file, background_file, foreground_file)
    if not all(os.path.exists(path) for path in core_files):
        print(f"Error: Missing core files for sample {domain}/{sample_id}")
        return None

    try:
        with open(prompt_file, 'r', encoding='utf-8') as fh:
            prompt_text = fh.read().strip()
    except Exception as e:
        print(f"Error reading prompt for {domain}/{sample_id}: {e}")
        return None

    # Collect whichever model output images are actually on disk.
    output_images = {}  # {model_key: path_to_image}
    for model_key, img_name in config.MODEL_OUTPUT_IMAGE_NAMES.items():
        candidate = os.path.join(base_dir, img_name)
        if os.path.exists(candidate):
            output_images[model_key] = candidate
        else:
            print(f"Warning: Missing output image {img_name} for model {model_key} in sample {domain}/{sample_id}")
            # A sample with missing outputs is still shown — it simply
            # offers fewer options. Ideally data prep guarantees all 4.

    if len(output_images) < len(config.MODEL_OUTPUT_IMAGE_NAMES):
        print(f"Warning: Sample {domain}/{sample_id} is missing one or more model outputs. It will have fewer than 4 options.")
    if not output_images:  # nothing at all to rate
        return None

    return {
        "prompt": prompt_text,
        "background_img_path": background_file,
        "foreground_img_path": foreground_file,
        "output_image_paths": output_images  # dict {model_key: path}
    }
def record_preference(df, session_id, domain, sample_id, prompt, bg_path, fg_path, displayed_models_info, preferred_model_key):
    """Append one preference record and return the resulting DataFrame.

    Args:
        df: Existing preferences DataFrame, or None to start a new one.
        displayed_models_info: list of (model_key, image_path) tuples in
            the order they were displayed.
        preferred_model_key: key of the model the user selected
            (e.g. "model_a").
    """
    record = {
        "session_id": session_id,
        "timestamp": datetime.now().isoformat(),
        "domain": domain,
        "sample_id": sample_id,
        "prompt": prompt,
        # Only basenames are stored, to keep the CSV compact.
        "input_background": os.path.basename(bg_path),
        "input_foreground": os.path.basename(fg_path),
        "preferred_model_key": preferred_model_key,
        "preferred_model_filename": config.MODEL_OUTPUT_IMAGE_NAMES.get(preferred_model_key, "N/A")
    }
    # Record the display order; pad unused slots (up to 4) with None so the
    # CSV schema stays fixed even when fewer models were shown.
    for slot in range(4):
        shown = displayed_models_info[slot][0] if slot < len(displayed_models_info) else None
        record[f"displayed_order_model_{slot + 1}"] = shown
    row_frame = pd.DataFrame([record], columns=config.CSV_HEADERS)
    return row_frame if df is None else pd.concat([df, row_frame], ignore_index=True)