# Hugging Face Space status: Running on Zero (ZeroGPU hardware).
import spaces | |
import os | |
import random | |
from PIL import Image | |
import torch | |
import gradio as gr | |
import dotenv | |
from adapter import load_ip_adapter_model, get_file_path | |
from example import EXAMPLES | |
# Load variables from .env.local into os.environ before reading configuration.
dotenv.load_dotenv(".env.local")


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is not set. An explicit raise is used
            instead of ``assert`` so the check still runs under ``python -O``
            and the error names the missing variable.
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(
            f"Required environment variable {name} is not set "
            "(define it in the environment or in .env.local)"
        )
    return value


# Adapter checkpoint location: a repo id plus the weight and config file paths
# inside it (presumably a Hugging Face Hub repo -- confirm against adapter.py).
ADAPTER_REPO_ID = _require_env("ADAPTER_REPO_ID")
ADAPTER_MODEL_PATH = _require_env("ADAPTER_MODEL_PATH")
ADAPTER_CONFIG_PATH = _require_env("ADAPTER_CONFIG_PATH")
# Base checkpoint to load; either value can be overridden via the environment.
BASE_MODEL_REPO_ID = os.getenv("BASE_MODEL_REPO_ID", "p1atdev/animagine-xl-4.0-bnb-nf4")
BASE_MODEL_PATH = os.getenv("BASE_MODEL_PATH", "animagine-xl-4.0-opt.bnb_nf4.safetensors")

# Initial value of the "Number of images to generate" slider.
INITIAL_BATCH_SIZE = int(os.getenv("INITIAL_BATCH_SIZE", "1"))
# Resolve local paths for the adapter weights, the adapter config and the base
# checkpoint. NOTE(review): get_file_path presumably downloads/caches the file
# from the Hub -- confirm in adapter.py.
adapter_model_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_MODEL_PATH)
adapter_config_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_CONFIG_PATH)
base_model_path = get_file_path(BASE_MODEL_REPO_ID, BASE_MODEL_PATH)

# Build the model once at import time from the three files above and place it
# on the first CUDA device; on_generate reuses this single instance.
model = load_ip_adapter_model(
    adapter_path=adapter_model_path,
    config_path=adapter_config_path,
    model_path=base_model_path,
)
model.to("cuda:0")
# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU-decorated
# function runs, so the CUDA work below needs it. Per the `spaces` docs the
# decorator has no effect outside ZeroGPU environments.
@spaces.GPU
def on_generate(
    prompt: str,
    negative_prompt: str,
    image: Image.Image | None,
    width: int,
    height: int,
    steps: int,
    cfg_scale: float,
    seed: int,
    randomize_seed: bool = True,
    num_images: int = 4,
):
    """Generate a batch of images for the Gradio UI.

    Args:
        prompt: Positive prompt; replicated ``num_images`` times to form the batch.
        negative_prompt: Negative prompt, passed to ``model.generate`` as-is.
        image: Optional reference image; converted to RGB when given, otherwise
            ``None`` is passed through as ``reference_image``.
        width: Output width in pixels.
        height: Output height in pixels.
        steps: Number of inference steps.
        cfg_scale: Classifier-free guidance scale.
        seed: Seed to use when ``randomize_seed`` is False.
        randomize_seed: If True, ignore ``seed`` and draw one from
            ``[0, 2147483647]``.
        num_images: Number of images to generate in one batch.

    Returns:
        ``(images, seed)``: whatever ``model.generate`` returns (shown in the
        gallery), and the seed actually used so the UI can display it.
    """
    # Gradio documents Slider values as floats; list replication and the
    # pipeline's size/step arguments need plain ints.
    num_images = int(num_images)
    width = int(width)
    height = int(height)
    steps = int(steps)
    seed = int(seed)

    if image is not None:
        image = image.convert("RGB")
    if randomize_seed:
        seed = random.randint(0, 2147483647)

    try:
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            images = model.generate(
                prompt=[prompt] * num_images,  # one copy of the prompt per output image
                negative_prompt=negative_prompt,
                reference_image=image,
                num_inference_steps=steps,
                cfg_scale=cfg_scale,
                width=width,
                height=height,
                seed=seed,
                do_offloading=False,
                device="cuda:0",
                max_token_length=225,
                execution_dtype=torch.bfloat16,
            )
    finally:
        # Release cached allocator memory even when generation raises, so a
        # failed request does not leave GPU memory reserved.
        torch.cuda.empty_cache()

    return images, seed
def main():
    """Build the Gradio Blocks UI, wire the Generate button, and launch the server."""
    default_prompt = "masterpiece, best quality"
    default_negative_prompt = (
        "lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, "
        "fewer digits, cropped, worst quality, low quality, low score, bad score, "
        "average score, signature, watermark, username, blurry"
    )

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: every generation input.
            with gr.Column():
                prompt_box = gr.TextArea(
                    label="Prompt",
                    value=default_prompt,
                    placeholder=default_prompt,
                    interactive=True,
                )
                reference_image = gr.Image(
                    label="Reference Image",
                    type="pil",
                    height=600,
                )

                with gr.Accordion("Negative Prompt", open=False):
                    negative_prompt_box = gr.TextArea(
                        label="Negative Prompt",
                        show_label=False,
                        value=default_negative_prompt,
                        interactive=True,
                    )

                with gr.Row():
                    width_slider = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=896,
                        interactive=True,
                    )
                    height_slider = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=1152,
                        interactive=True,
                    )

                with gr.Accordion("Advanced options", open=False):
                    batch_slider = gr.Slider(
                        label="Number of images to generate",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=INITIAL_BATCH_SIZE,
                        interactive=True,
                    )
                    with gr.Row():
                        # Also used as an output: on_generate writes back the seed it used.
                        seed_slider = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=2147483647,
                            step=1,
                            value=0,
                        )
                        randomize_checkbox = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                            interactive=True,
                            scale=1,
                        )
                    steps_slider = gr.Slider(
                        label="Inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                        interactive=True,
                    )
                    cfg_slider = gr.Slider(
                        label="CFG scale",
                        minimum=3.0,
                        maximum=8.0,
                        step=0.5,
                        value=5.0,
                        interactive=True,
                    )

            # Right column: trigger button and results.
            with gr.Column():
                run_button = gr.Button(
                    "Generate",
                    variant="primary",
                )
                gallery = gr.Gallery(
                    label="Generated images",
                    type="pil",
                    rows=2,
                    height="768px",
                    preview=True,
                    show_label=True,
                )
                # Hidden sink for the per-example comment column of EXAMPLES.
                example_comment = gr.Markdown(
                    label="Comment",
                    visible=False,
                )

        gr.Examples(
            examples=EXAMPLES,
            inputs=[reference_image, prompt_box, width_slider, height_slider, example_comment],
            cache_examples=False,
        )

        # Order must match on_generate's positional parameters.
        generate_inputs = [
            prompt_box,
            negative_prompt_box,
            reference_image,
            width_slider,
            height_slider,
            steps_slider,
            cfg_slider,
            seed_slider,
            randomize_checkbox,
            batch_slider,
        ]
        gr.on(
            triggers=[run_button.click],
            fn=on_generate,
            inputs=generate_inputs,
            outputs=[gallery, seed_slider],
        )

    demo.launch()


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    main()