import spaces  # ZeroGPU support; keep this import before torch

import os
import random

from PIL import Image
import torch
import gradio as gr
import dotenv

from adapter import load_ip_adapter_model, get_file_path
from example import EXAMPLES

dotenv.load_dotenv(".env.local")
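
# The IP-Adapter repo and file locations must be supplied via environment
# variables (e.g. in .env.local); there are no defaults for these.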
ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID")
ADAPTER_MODEL_PATH = os.environ.get("ADAPTER_MODEL_PATH")
ADAPTER_CONFIG_PATH = os.environ.get("ADAPTER_CONFIG_PATH")
assert ADAPTER_REPO_ID is not None
assert ADAPTER_MODEL_PATH is not None
assert ADAPTER_CONFIG_PATH is not None
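
# The base model defaults to the bitsandbytes NF4-quantized Animagine XL 4.0 checkpoint.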
BASE_MODEL_REPO_ID = os.environ.get(
    "BASE_MODEL_REPO_ID", "p1atdev/animagine-xl-4.0-bnb-nf4"
)
BASE_MODEL_PATH = os.environ.get(
    "BASE_MODEL_PATH", "animagine-xl-4.0-opt.bnb_nf4.safetensors"
)
INITIAL_BATCH_SIZE = int(os.environ.get("INITIAL_BATCH_SIZE", 1))
adapter_model_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_MODEL_PATH)
adapter_config_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_CONFIG_PATH)
base_model_path = get_file_path(BASE_MODEL_REPO_ID, BASE_MODEL_PATH)
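
# Load the base model with the IP-Adapter attached and move it to the GPU
# (get_file_path is expected to resolve, and if needed fetch, the checkpoint files).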
model = load_ip_adapter_model(
    model_path=base_model_path,
    config_path=adapter_config_path,
    adapter_path=adapter_model_path,
)
model.to("cuda:0")
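
# on_generate runs inside a ZeroGPU-allocated context via the @spaces.GPU decorator
# and returns the generated images together with the seed that was actually used.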
@spaces.GPU
def on_generate(
    prompt: str,
    negative_prompt: str,
    image: Image.Image | None,
    width: int,
    height: int,
    steps: int,
    cfg_scale: float,
    seed: int,
    randomize_seed: bool = True,
    num_images: int = 4,
):
    if image is not None:
        image = image.convert("RGB")

    if randomize_seed:
        seed = random.randint(0, 2147483647)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        images = model.generate(
            prompt=[prompt] * num_images,  # repeat the prompt once per image in the batch
            negative_prompt=negative_prompt,
            reference_image=image,
            num_inference_steps=steps,
            cfg_scale=cfg_scale,
            width=width,
            height=height,
            seed=seed,
            do_offloading=False,
            device="cuda:0",
            max_token_length=225,
            execution_dtype=torch.bfloat16,
        )

    torch.cuda.empty_cache()  # release cached VRAM between requests

    return images, seed
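
# Gradio UI: a left column with the prompt, reference image, and sampling options,
# and a right column with the Generate button and the output gallery.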
def main():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    value="masterpiece, best quality",
                    placeholder="masterpiece, best quality",
                    interactive=True,
                )
                input_image = gr.Image(
                    label="Reference Image",
                    type="pil",
                    height=600,
                )

                with gr.Accordion("Negative Prompt", open=False):
                    negative_prompt = gr.TextArea(
                        label="Negative Prompt",
                        show_label=False,
                        value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                        interactive=True,
                    )

                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=896,
                        interactive=True,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=1152,
                        interactive=True,
                    )

                with gr.Accordion("Advanced options", open=False):
                    num_images = gr.Slider(
                        label="Number of images to generate",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=INITIAL_BATCH_SIZE,
                        interactive=True,
                    )
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=2147483647,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                            interactive=True,
                            scale=1,
                        )
                    steps = gr.Slider(
                        label="Inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                        interactive=True,
                    )
                    cfg_scale = gr.Slider(
                        label="CFG scale",
                        minimum=3.0,
                        maximum=8.0,
                        step=0.5,
                        value=5.0,
                        interactive=True,
                    )

            with gr.Column():
                generate_button = gr.Button(
                    "Generate",
                    variant="primary",
                )
                output_image = gr.Gallery(
                    label="Generated images",
                    type="pil",
                    rows=2,
                    height="768px",
                    preview=True,
                    show_label=True,
                )
                comment = gr.Markdown(
                    label="Comment",
                    visible=False,
                )

                gr.Examples(
                    examples=EXAMPLES,
                    inputs=[input_image, prompt, width, height, comment],
                    cache_examples=False,
                )

        gr.on(
            triggers=[generate_button.click],
            fn=on_generate,
            inputs=[
                prompt,
                negative_prompt,
                input_image,
                width,
                height,
                steps,
                cfg_scale,
                seed,
                randomize_seed,
                num_images,
            ],
            outputs=[output_image, seed],
        )

    demo.launch()
if __name__ == "__main__":
    main()