Lifeinhockey's picture
Update app.py
4fafed4 verified
raw
history blame
8.72 kB
import os
import gradio as gr
import numpy as np
import random
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel, LoraConfig
def get_lora_sd_pipeline(
    ckpt_dir='./lora_man_animestyle',
    base_model_name_or_path=None,
    dtype=torch.float16,
    adapter_name="default"
):
    """Build a StableDiffusionPipeline with saved PEFT LoRA adapters attached.

    The UNet adapter is read from ``<ckpt_dir>/unet``. The text-encoder adapter
    is read from ``<ckpt_dir>/text_encoder`` when that directory exists.

    Args:
        ckpt_dir: Directory containing the saved adapter sub-directories.
        base_model_name_or_path: Hub id or local path of the base model. When
            ``None``, it is taken from the text-encoder adapter's ``LoraConfig``.
        dtype: Torch dtype used to load the pipeline. When it is
            ``torch.float16`` or ``torch.bfloat16``, the adapted UNet and the
            text encoder are cast to exactly that dtype after loading.
        adapter_name: Name under which the adapters are registered and activated.

    Returns:
        The pipeline, with ``pipe.unet`` wrapped in ``PeftModel``. If a
        text-encoder adapter exists, ``pipe.text_encoder`` is wrapped too.

    Raises:
        ValueError: If no base model is given and none can be read from a
            text-encoder adapter config.
    """
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
    has_text_encoder_adapter = os.path.exists(text_encoder_sub_dir)

    if has_text_encoder_adapter and base_model_name_or_path is None:
        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
        base_model_name_or_path = config.base_model_name_or_path
    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)

    # parameters() returns a lazy generator. Comparing two such generators only
    # walks the modules after the adapter has been injected, so "before" and
    # "after" were the same tensors and the check always printed False.
    # Snapshot a parameter count before wrapping instead.
    n_params_before = sum(p.numel() for p in pipe.unet.parameters())
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
    pipe.unet.set_adapter(adapter_name)
    n_params_after = sum(p.numel() for p in pipe.unet.parameters())
    print("Parameters changed:", n_params_after != n_params_before)

    if has_text_encoder_adapter:
        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)

    if dtype in (torch.float16, torch.bfloat16):
        # Cast to the requested dtype. The previous `.half()` turned a bfloat16
        # request into float16, mismatching the rest of the bf16 pipeline.
        # The cast also brings the adapter weights (which PEFT may load in
        # float32; assumption, verify for the installed version) to the same
        # precision as the base model.
        pipe.unet.to(dtype)
        pipe.text_encoder.to(dtype)
    return pipe
def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
    """Encode ``prompt`` without truncating it to a fixed token window.

    The full token-id tensor is sliced along dim 1 into consecutive windows of
    ``max_length`` ids; the final window may be shorter. Each window is run
    through ``text_encoder`` separately under ``torch.no_grad()``. The first
    output of every call is concatenated back together along dim 1.

    Args:
        prompt: Text to encode.
        tokenizer: Callable that returns a mapping with an ``"input_ids"``
            tensor indexed as ``[:, positions]``.
        text_encoder: Callable model exposing ``.device``; item ``[0]`` of its
            output is the tensor that gets concatenated.
        max_length: Number of token ids per window.

    Returns:
        The per-window encoder outputs joined on dim 1.
    """
    input_ids = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
    seq_len = input_ids.shape[1]

    window_outputs = []
    with torch.no_grad():
        for start in range(0, seq_len, max_length):
            window = input_ids[:, start:start + max_length]
            encoded = text_encoder(window.to(text_encoder.device))
            window_outputs.append(encoded[0])
    return torch.cat(window_outputs, dim=1)
def align_embeddings(prompt_embeds, negative_prompt_embeds):
    """Zero-pad two embedding tensors so their ``shape[1]`` lengths match.

    The target length is the larger of the two ``shape[1]`` values. Trailing
    zeros are added to the second-to-last dimension via ``F.pad`` with
    ``(0, 0, 0, k)``, and the last dimension is left untouched. For the 3-D
    tensors produced by ``process_prompt``, the second-to-last dimension is
    dim 1. This assumes 3-D inputs; for other ranks the padded axis and
    ``shape[1]`` would differ.

    Returns:
        Tuple ``(prompt_embeds, negative_prompt_embeds)`` after padding.
    """
    target_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])

    def pad_to_target(embeds):
        # F.pad's tuple is read from the last dim backwards:
        # (last_before, last_after, second_last_before, second_last_after).
        missing = target_len - embeds.shape[1]
        return torch.nn.functional.pad(embeds, (0, 0, 0, missing))

    padded_prompt = pad_to_target(prompt_embeds)
    padded_negative = pad_to_target(negative_prompt_embeds)
    return padded_prompt, padded_negative
# Runtime configuration: use the GPU with float16 weights when CUDA is
# available, otherwise fall back to the CPU with float32.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Default pipeline (base model + LoRA adapters from ./lora_man_animestyle),
# built once when the module is imported and moved to `device`.
pipe_default = get_lora_sd_pipeline(
    ckpt_dir='./lora_man_animestyle',
    base_model_name_or_path=model_id_default,
    dtype=torch_dtype,
).to(device)

# UI limits: largest accepted seed (int32 max) and max image side in pixels.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def _encode_prompt_pair(pipe, prompt, negative_prompt):
    """Encode prompt and negative prompt with ``pipe``'s text encoder and equalize their lengths."""
    prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
    negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
    return align_embeddings(prompt_embeds, negative_prompt_embeds)


def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    num_inference_steps=20,
    model_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
    seed=42,
    guidance_scale=7.5,
    lora_scale=0.5,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a single image and return ``pipe(...).images[0]``.

    Args:
        prompt: Positive prompt text. Long prompts are encoded in windows by
            ``process_prompt``.
        negative_prompt: Negative prompt text, encoded the same way.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        model_id: Hub id of the model. ``model_id_default`` selects the
            preloaded LoRA pipeline. Any other id is loaded with
            ``StableDiffusionPipeline.from_pretrained`` on every call and runs
            without LoRA.
        seed: RNG seed. Converted with ``int()`` because ``gr.Number`` may
            deliver a float.
        guidance_scale: Classifier-free guidance strength.
        lora_scale: LoRA strength for this request (default pipeline only).
        progress: Gradio progress hook (``track_tqdm=True``); not referenced
            in the body.
    """
    # torch.Generator.manual_seed only accepts ints; gr.Number yields e.g. 42.0.
    generator = torch.Generator(device).manual_seed(int(seed))

    params = {
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'width': width,
        'height': height,
        'generator': generator,
    }

    if model_id != model_id_default:
        # Plain base model without LoRA.
        # NOTE(review): reloaded from scratch on every request; cache it if this path is hot.
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
        prompt_embeds, negative_prompt_embeds = _encode_prompt_pair(pipe, prompt, negative_prompt)
    else:
        from diffusers.utils import scale_lora_layers, unscale_lora_layers

        pipe = pipe_default
        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
        print(f"LoRA scale applied: {lora_scale}")
        # Apply lora_scale for this request only. The previous
        # pipe.fuse_lora(lora_scale=...) merged the adapter into the base
        # weights of the shared pipe_default on every call, mutating it
        # permanently instead of honoring the current slider value.
        scale_lora_layers(pipe.text_encoder, lora_scale)
        try:
            prompt_embeds, negative_prompt_embeds = _encode_prompt_pair(pipe, prompt, negative_prompt)
        finally:
            unscale_lora_layers(pipe.text_encoder, lora_scale)
        # diffusers' UNet applies cross_attention_kwargs["scale"] to its PEFT
        # LoRA layers for each forward pass and reverts it afterwards.
        params['cross_attention_kwargs'] = {'scale': lora_scale}

    params['prompt_embeds'] = prompt_embeds
    params['negative_prompt_embeds'] = negative_prompt_embeds
    return pipe(**params).images[0]
# Sample positive prompts shown under the prompt box via gr.Examples.
# Every entry needs a trailing comma: a missing one makes Python concatenate
# adjacent string literals, silently merging two examples into one.
examples = [
    "Young man in anime style. The image is of high sharpness and resolution. A handsome, thoughtful man. The man is depicted in the foreground, close-up or middle plan. The background is blurry, not sharp. The play of light and shadow is visible on the face and clothes.",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
    "An astronaut riding a green horse.",
    "A delicious ceviche cheesecake slice.",
    "A futuristic sports car is located on the surface of Mars. Stars, planets, mountains and craters are visible.",
]
# Sample negative prompts shown under the negative-prompt box via gr.Examples.
# Every entry needs a trailing comma; without it adjacent literals are
# concatenated into a single example.
examples_negative = [
    "blurred details, low resolution, poor image of a man's face, poor quality, artifacts, black and white image",
    "blurry details, low resolution, poorly defined edges",
    "bad face, bad quality, artifacts, low-res, black and white",
]
# Page CSS passed to gr.Blocks: centers the element with elem_id
# "col-container" and caps its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Hub ids offered in the "Model Selection" dropdown. The first entry equals
# model_id_default, which infer() serves with the preloaded LoRA pipeline;
# any other entry is loaded by infer() with StableDiffusionPipeline.from_pretrained.
# NOTE(review): "stabilityai/sdxl-turbo" and the "stabilityai/stable-diffusion-3*"
# repos are presumably not SD 1.x-style checkpoints and may not load through
# StableDiffusionPipeline / process_prompt's single text encoder - confirm.
available_models = [
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "SG161222/Realistic_Vision_V3.0_VAE",
    "CompVis/stable-diffusion-v1-4",
    "stabilityai/sdxl-turbo",
    "runwayml/stable-diffusion-v1-5",
    "sd-legacy/stable-diffusion-v1-5",
    "prompthero/openjourney",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-3.5-large",
    "stabilityai/stable-diffusion-3.5-large-turbo",
]
# Gradio UI: model picker, prompt inputs, generation controls, examples and the
# result image. The run button and pressing Enter in the prompt box trigger infer().
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image Gradio Template from V. Gorsky")
        with gr.Row():
            model_id = gr.Dropdown(
                label="Model Selection",
                choices=available_models,
                value="stable-diffusion-v1-5/stable-diffusion-v1-5",
                interactive=True
            )
        prompt = gr.Text(
            label="Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
            container=False,
        )
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=True,
        )
        with gr.Row():
            lora_scale = gr.Slider(
                label="LoRA scale",
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.5,
            )
        with gr.Row():
            # precision=0 makes the component deliver an int. With the default
            # precision it delivers a float, which torch.Generator.manual_seed rejects.
            seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
                precision=0,
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                value=7.5,  # Replace with defaults that work for your model
            )
        with gr.Row():
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=100,
                step=1,
                value=30,  # Replace with defaults that work for your model
            )
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
        gr.Examples(examples=examples, inputs=[prompt])
        gr.Examples(examples=examples_negative, inputs=[negative_prompt])
        run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)

    # Gradio passes `inputs` to `fn` positionally, so this list must follow
    # infer()'s parameter order: (prompt, negative_prompt, width, height,
    # num_inference_steps, model_id, seed, guidance_scale, lora_scale).
    # The previous order sent model_id into `prompt`, seed into `height`, etc.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            num_inference_steps,
            model_id,
            seed,
            guidance_scale,
            lora_scale,
        ],
        outputs=[result],
    )
# Launch the Gradio app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()