CHATS / app.py
Flourish's picture
Update app.py
586c31c verified
raw
history blame
3.78 kB
import torch
import gradio as gr
import spaces
import random
import numpy as np
from pipeline import ChatsSDXLPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from diffusers.utils import logging
from PIL import Image
# Only let diffusers report errors.
logging.set_verbosity_error()

# Upper bound for seeds: the largest value a signed 32-bit integer can hold.
MAX_SEED = np.iinfo(np.int32).max

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Components passed to the pipeline's `safety_checker` / `feature_extractor`
# arguments below.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "openai/clip-vit-base-patch32"
)

# Load CHATS-SDXL pipeline
pipe = ChatsSDXLPipeline.from_pretrained(
    "AIDC-AI/CHATS",
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    torch_dtype=torch.bfloat16,
)
pipe.to(DEVICE)
@spaces.GPU(duration=75)
def generate(prompt, seed=0, randomize_seed=False, steps=50, guidance_scale=5.0):
    """Generate one image for *prompt* with the CHATS-SDXL pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed forwarded to the pipeline. Replaced by a random value
            when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed from ``[0, MAX_SEED]``
            (both ends inclusive).
        steps: Number of inference steps forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The first entry of the pipeline output's ``"images"``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Gradio documents Slider values as floats, so whole numbers can arrive
    # as e.g. 50.0. Normalise the two integer-valued settings before they are
    # logged and handed to the pipeline.
    seed = int(seed)
    steps = int(steps)

    print(
        f"inference with prompt : {prompt}, seed : {seed}, "
        f"step : {steps}, cfg : {guidance_scale}"
    )
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return output["images"][0]
# Prompts offered as ready-made examples in the UI.
examples = [
    "Solar punk vehicle in a bustling city",
    (
        "An anthropomorphic cat riding a Harley Davidson in Arizona "
        "with sunglasses and a leather jacket"
    ),
    (
        "An elderly woman poses for a high fashion photoshoot in colorful, "
        "patterned clothes with a cyberpunk 2077 vibe"
    ),
]
# Stylesheet for the page: auto side margins and a 520px width cap on the
# element with id "col-container".
css = "\n".join(
    (
        "",
        "#col-container {",
        "margin: 0 auto;",
        "max-width: 520px;",
        "}",
        "",
    )
)
# Page layout: one centred column holding the title, the prompt row, the
# result image, the advanced settings and the example prompts.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# CHATS-SDXL
SDXL diffusion models finetuned using preference optimization framework CHATS. [[paper](https://arxiv.org/pdf/2502.12579)] [[code](https://github.com/AIDC-AI/CHATS)] [[model](https://huggingface.co/AIDC-AI/CHATS)]
""")

        # Prompt box and run button share one row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt here",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        # Sampling controls, collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale", minimum=1, maximum=14, step=0.1, value=5.0
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=1, maximum=100, step=1, value=50
                )

        # Examples supply only the prompt, so `generate` runs with its
        # default values for every other argument.
        gr.Examples(
            examples=examples,
            fn=generate,
            inputs=[prompt],
            outputs=[result],
            cache_examples="lazy",
        )

    # Clicking "Run" or submitting the prompt box both trigger `generate`.
    # The input order matches generate's positional parameters.
    generation_inputs = [prompt, seed, randomize_seed, num_inference_steps, guidance_scale]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate,
        inputs=generation_inputs,
        outputs=[result],
    )
# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()