|
import torch |
|
import gradio as gr |
|
|
|
from pipeline import ChatsSDXLPipeline |
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
|
from transformers import CLIPFeatureExtractor |
|
from diffusers.utils import logging |
|
from PIL import Image |
|
|
|
# Restrict diffusers' logger to error-level messages.
logging.set_verbosity_error()

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only used on CUDA: diffusers warns that pipelines loaded
# in float16 cannot run on the CPU, so the CPU fallback uses float32 instead.
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Components handed to the pipeline for NSFW filtering of generated images.
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

# Load the CHATS pipeline once at import time and move it to the chosen device.
pipe = ChatsSDXLPipeline.from_pretrained(
    "AIDC-AI/CHATS",
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    torch_dtype=TORCH_DTYPE,
)
pipe.to(DEVICE)
|
|
|
def generate(prompt, steps=50, guidance_scale=7.5, height=768, width=512):
    """Generate a single image for ``prompt`` using the module-level ``pipe``.

    Args:
        prompt: Text description forwarded to the pipeline.
        steps: Number of inference steps; coerced to ``int``.
        guidance_scale: Guidance scale forwarded to the pipeline unchanged.
        height: Requested image height; coerced to ``int``.
        width: Requested image width; coerced to ``int``.

    Returns:
        A ``PIL.Image.Image`` built from the first entry of the pipeline's
        ``'images'`` output.
    """
    output = pipe(
        prompt=prompt,
        # Gradio documents slider values as floats, but a step count and
        # pixel dimensions have to be whole numbers.
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        height=int(height),
        width=int(width),
        # NOTE(review): the seed is hard-coded, presumably so that repeated
        # requests are reproducible -- confirm against the pipeline.
        seed=0,
    )
    image = output['images'][0]
    # The pipeline output is presumably an array (the original converted it
    # unconditionally); skip the conversion if it is already a PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    return image
|
|
|
def _generate_for_gallery(prompt, steps, guidance_scale, height, width):
    """Adapt ``generate`` to ``gr.Gallery``, which takes a list of images."""
    return [generate(prompt, steps, guidance_scale, height, width)]


# Gradio UI: a prompt box, four sliders for the sampling settings, a button
# that triggers generation, and a gallery that shows the result.
with gr.Blocks(title="🔥 CHATS-SDXL Demo") as demo:
    gr.Markdown(
        "## CHATS-SDXL Text-to-Image Demo\n\n"
        "Enter your prompt and click **Generate Image**. All NSFW content will be automatically filtered."
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            placeholder="Enter your description here...",
            lines=2,
        )

    with gr.Row():
        steps_slider = gr.Slider(
            minimum=1, maximum=100, value=50, step=1,
            label="Inference Steps",
        )
        scale_slider = gr.Slider(
            minimum=1.0, maximum=14.0, value=5.0, step=0.1,
            label="Guidance Scale",
        )

    with gr.Row():
        height_slider = gr.Slider(
            minimum=64, maximum=2048, value=1024, step=64,
            label="Image Height",
        )
        width_slider = gr.Slider(
            minimum=64, maximum=2048, value=1024, step=64,
            label="Image Width",
        )

    generate_button = gr.Button("Generate Image")

    gallery = gr.Gallery(
        label="Generated Images",
        show_label=False,
        columns=2,
        elem_id="gallery",
    )

    # The gallery iterates over its value, so the single image returned by
    # generate() is wrapped in a list by the adapter above.
    generate_button.click(
        fn=_generate_for_gallery,
        inputs=[prompt_input, steps_slider, scale_slider, height_slider, width_slider],
        outputs=[gallery],
    )
|
|
|
# Start the Gradio app only when this file is executed as a script, so that
# importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()
|
|