# 1Prompt1Story/resource/gen_benchmark.py
import argparse
import os
import queue
import random
import threading
import time

import torch
import yaml
from tqdm import tqdm

from main import generate_images, load_unet_controller
from unet import utils
def main_ben(unet_controller, pipe, save_dir, id_prompt, frame_prompt_list, seed, window_length):
    # Reset the controller's IPCA bookkeeping so state does not carry over
    # from a previously generated story.
    unet_controller.ipca_index = -1
    unet_controller.ipca_time_step = -1

    os.makedirs(save_dir, exist_ok=True)
    images, story_image = generate_images(unet_controller, pipe, id_prompt, frame_prompt_list, save_dir, window_length, seed, verbose=False)
    return images, story_image
def process_instance(unet_controller, pipe, instance):
# Unpack instance and execute task
save_dir, id_prompt, frame_prompt_list, seed, window_length = instance
return main_ben(unet_controller, pipe, save_dir, id_prompt, frame_prompt_list, seed, window_length)
def worker(device, unet_controller, pipe, task_queue, pbar):
    # Drain the shared queue. get_nowait avoids the race between empty() and
    # get() that can leave a worker blocked forever when several threads pull
    # from the same queue.
    while True:
        try:
            instance = task_queue.get_nowait()
        except queue.Empty:
            break
        process_instance(unet_controller, pipe, instance)
        print(f"Finished processing {instance[1]} on {device}")  # instance[1] is the id_prompt
        task_queue.task_done()  # Mark the task as done
        pbar.update(1)  # Update the shared progress bar
def main():
    parser = argparse.ArgumentParser(description="Generate benchmark story images for every instance in a benchmark YAML file.")
    parser.add_argument('--device', type=str, choices=['cuda:0', 'cuda:1', 'cuda'], default='cuda', help='Device to use when running on a single GPU')
    parser.add_argument('--save_dir', type=str, required=True, help='Directory where generated images are saved')
    parser.add_argument('--benchmark_path', type=str, required=True, help='Path to the benchmark YAML file')
    parser.add_argument('--model_path', type=str, default='stabilityai/stable-diffusion-xl-base-1.0', help='Path to the model')
    parser.add_argument('--precision', type=str, choices=["fp16", "fp32"], default="fp16", help='Model precision')
    parser.add_argument('--window_length', type=int, default=10, help='Window length for story generation')
    parser.add_argument('--num_gpus', type=int, default=2, help='Number of GPUs to use')
    parser.add_argument('--fix_seed', type=int, default=42, help='Fixed random seed; pass -1 for a random seed per instance')
    args = parser.parse_args()
# Create a list of devices
devices = [f'cuda:{i}' for i in range(args.num_gpus)] # List of device names
if args.num_gpus == 1:
devices = [args.device]
# Load unet_controllers and pipes for each device
unet_controllers = {}
pipes = {}
for device in devices:
pipe, _ = utils.load_pipe_from_path(args.model_path, device, torch.float16 if args.precision == "fp16" else torch.float32, args.precision)
unet_controller = load_unet_controller(pipe, device)
unet_controller.Save_story_image = False
unet_controller.Prompt_embeds_mode = "svr-eot"
# unet_controller.Is_freeu_enabled = True
unet_controllers[device] = unet_controller
pipes[device] = pipe
# Load the benchmark data
with open(os.path.expanduser(args.benchmark_path), 'r') as file:
data = yaml.safe_load(file)
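    # The benchmark YAML is expected to map each subject domain to a list of
    # instances, where every instance provides "style", "subject", and
    # "settings" (the per-frame prompts), as read by the loop below.
    # Illustrative layout (keys from the code; values are made up):
    #
    #   animals:
    #     - style: "A watercolor painting of"
    #       subject: "a corgi"
    #       settings:
    #         - "running on the beach"
    #         - "sleeping under a tree"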
    instances = []
    for subject_domain, subject_domain_instances in data.items():
        for index, instance in enumerate(subject_domain_instances):
            id_prompt = f'{instance["style"]} {instance["subject"]}'
            frame_prompt_list = instance["settings"]
            save_dir = os.path.join(args.save_dir, f"{subject_domain}_{index}")
            if args.fix_seed != -1:
                seed = args.fix_seed
            else:
                seed = random.randint(0, 2**32 - 1)
            instances.append((save_dir, id_prompt, frame_prompt_list, seed, args.window_length))
# Create a task queue and populate it with instances
task_queue = queue.Queue()
for instance in instances:
task_queue.put(instance)
# Initialize tqdm progress bar
pbar = tqdm(total=len(instances))
    # Create one worker thread per device, all pulling from the shared queue
    threads = []
    for device in devices:
        unet_controller = unet_controllers[device]
        pipe = pipes[device]
        thread = threading.Thread(target=worker, args=(device, unet_controller, pipe, task_queue, pbar))
        threads.append(thread)
        thread.start()
        time.sleep(1)  # Stagger thread start-up before launching the next worker
# Wait for all threads to finish
for thread in threads:
thread.join()
# Close the progress bar
pbar.close()
if __name__ == "__main__":
main()
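# Example invocation (paths are illustrative; flags as defined in main()):
#   python resource/gen_benchmark.py \
#       --save_dir ./benchmark_output \
#       --benchmark_path ./resource/benchmark.yaml \
#       --num_gpus 2 --precision fp16 --fix_seed 42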