# Source: app.py by macrdel — commit a2c9c1b ("update app.py"), 6.28 kB
import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
from peft import PeftModel, PeftConfig
import torch
# Pick the compute device and matching precision once at import time:
# half precision on a CUDA GPU, full precision when running on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Hub repo ids offered in the model dropdown. The last entry is a LoRA
# adapter rather than a full pipeline; load_pipeline treats it specially.
MODEL_LIST = [
    "CompVis/stable-diffusion-v1-4",
    "stabilityai/sdxl-turbo",
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1",
    "macrdel/unico_proj",
]

# model_id -> loaded pipeline, filled lazily so each model is built only once.
model_cache = dict()
def load_pipeline(model_id: str):
    """
    Load a DiffusionPipeline for ``model_id``, or return the cached one.

    For the LoRA adapter repo ("macrdel/unico_proj") the base model
    CompVis/stable-diffusion-v1-4 is loaded first and the adapter stored in
    the repo's ``unet`` and ``text_encoder`` subfolders is attached with PEFT.
    Any other id is loaded directly as a complete pipeline.

    The resulting pipeline is moved to ``device`` and stored in
    ``model_cache`` so later calls with the same id reuse it.

    Args:
        model_id: Hugging Face Hub repo id, normally one of MODEL_LIST.

    Returns:
        The pipeline, already moved to ``device``.
    """
    if model_id in model_cache:
        return model_cache[model_id]

    if model_id == "macrdel/unico_proj":
        pipe = _load_lora_pipeline(model_id, base_model="CompVis/stable-diffusion-v1-4")
    else:
        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)

    pipe.to(device)
    model_cache[model_id] = pipe
    return pipe


def _load_lora_pipeline(adapter_id: str, base_model: str):
    """Load ``base_model`` and wrap its UNet and text encoder with the PEFT adapter ``adapter_id``."""
    pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
    # PeftModel.from_pretrained has no ``torch_dtype`` parameter: extra keyword
    # arguments are forwarded to the PEFT config class, so passing the dtype
    # there did not control the adapter's precision. Only ``subfolder`` is needed
    # to select each component's adapter weights inside the repo.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, adapter_id, subfolder="unet")
    pipe.text_encoder = PeftModel.from_pretrained(
        pipe.text_encoder, adapter_id, subfolder="text_encoder"
    )
    return pipe
# Upper bound for seeds (slider maximum and random draws): the int32 maximum.
MAX_SEED = int(np.iinfo(np.int32).max)
# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 1024
def _set_peft_lora_scale(module, scale):
    """
    Set the LoRA scaling factor on every PEFT LoRA layer inside ``module``.

    LoRA layers are found by duck typing: PEFT LoRA layers keep a per-adapter
    ``scaling`` dict and provide ``set_scale(adapter_name, scale)``. That method
    sets the scale absolutely (relative to the trained alpha/r), not
    cumulatively, so calling this on every request is safe.

    Returns:
        The number of layers whose scale was set (0 if no LoRA layers exist).
    """
    updated = 0
    for layer in module.modules():
        scaling = getattr(layer, "scaling", None)
        setter = getattr(layer, "set_scale", None)
        if isinstance(scaling, dict) and callable(setter):
            for adapter_name in list(scaling):
                setter(adapter_name, scale)
            updated += 1
    return updated


def infer(
    model_id,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lora_scale,  # LoRA strength multiplier; 1.0 keeps the adapter's trained scale
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate one image from ``prompt`` with the selected model.

    Args:
        model_id: Hub repo id chosen in the dropdown (see load_pipeline).
        prompt: Positive text prompt.
        negative_prompt: Text the sampler should steer away from.
        seed: Seed for the torch generator; ignored if ``randomize_seed``.
        randomize_seed: When true, draw a fresh seed in [0, MAX_SEED].
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        lora_scale: Multiplier applied to the LoRA adapter's layers in both the
            UNet and the text encoder; only used for "macrdel/unico_proj".
        progress: Gradio progress tracker (follows the pipeline's tqdm bars).

    Returns:
        ``(image, seed)``: the first generated image and the seed actually used,
        so the UI can display it when a random seed was drawn.
    """
    pipe = load_pipeline(model_id)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; the generator and sizes need ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    if model_id == "macrdel/unico_proj":
        # The adapter is attached with PEFT, whose wrapper exposes no
        # ``set_lora_scale``; the scale lives on each individual LoRA layer,
        # in both the UNet and the text encoder.
        updated = sum(
            _set_peft_lora_scale(component, lora_scale)
            for component in (pipe.unet, pipe.text_encoder)
        )
        if not updated:
            print("Warning: no PEFT LoRA layers found; LoRA scale not applied.")

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]
    return image, seed
# Example prompts listed under the UI; selecting one fills the prompt box.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
# Center the main column and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Page layout. Gradio places components in the order they are created inside
# these context managers, so statement order here defines the visual layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image Gradio Template")
        with gr.Row():
            # Model picker; its value is passed to infer as ``model_id``.
            model_id = gr.Dropdown(
                label="Model",
                choices=MODEL_LIST,
                value=MODEL_LIST[0],  # Default: first entry of MODEL_LIST
            )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        # Collapsed-by-default panel with the generation parameters.
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,  # Used only while "Randomize seed" is unchecked
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # NOTE(review): the default size equals MAX_IMAGE_SIZE (1024);
                # confirm this suits the default model, MODEL_LIST[0].
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.5,
                    value=7.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=20,
                )
            # LoRA strength; infer only applies it when the LoRA model
            # (macrdel/unico_proj) is selected.
            lora_scale = gr.Slider(
                label="LoRA Scale",
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                info="Adjust the influence of the LoRA weights",
            )
        gr.Examples(examples=examples, inputs=[prompt])
    # Run infer on a Run click or on Enter in the prompt box. The input order
    # must match infer's positional parameters. infer returns (image, seed),
    # so the seed slider is updated with the seed that was actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            model_id,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            lora_scale,  # Last positional argument of infer
        ],
        outputs=[result, seed],
    )
# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()