# CHATS / app.py
# Uploaded by Flourish ("Upload 5 files"), commit c7db14f (verified), 2.5 kB.
import torch
import gradio as gr
from pipeline import ChatsSDXLPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from diffusers.utils import logging
from PIL import Image
# Silence diffusers' info/warning logging; only errors are printed.
logging.set_verbosity_error()

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only worthwhile (and reliably supported) on CUDA. PyTorch's
# CPU float16 support is incomplete (e.g. conv2d has historically had no Half
# kernel on CPU), so the CPU fallback loads the weights in float32 instead of
# unconditionally using float16.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Preprocessor + NSFW classifier handed to the pipeline below. Presumably the
# pipeline runs the checker on its outputs (ChatsSDXLPipeline lives in
# pipeline.py, which is not visible here — confirm).
# NOTE(review): the safety checker is loaded without `torch_dtype`, so it is
# float32 while the pipeline may run in float16; confirm pipeline.py casts the
# checker or its inputs consistently.
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)

# Load CHATS-SDXL pipeline
pipe = ChatsSDXLPipeline.from_pretrained(
    "AIDC-AI/CHATS",
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    torch_dtype=DTYPE,
)
pipe.to(DEVICE)
def generate(prompt, steps=50, guidance_scale=7.5, height=768, width=512, seed=0):
    """Generate one image for *prompt* with the module-level CHATS-SDXL pipeline.

    Args:
        prompt: Text description of the image to generate.
        steps: Number of inference steps, forwarded as ``num_inference_steps``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        height: Output height in pixels.
        width: Output width in pixels.
        seed: Value forwarded to the pipeline's ``seed`` argument. The default
            of 0 matches the value that was previously hard-coded.

    Returns:
        A ``PIL.Image.Image`` built from the first entry of the pipeline
        output's ``images`` field.

    Raises:
        gr.Error: If *prompt* is missing or contains only whitespace; Gradio
            shows this message to the user instead of running the model.
    """
    if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
        raise gr.Error("Please enter a prompt.")

    # Gradio sliders can hand back numbers as floats; the step count and pixel
    # dimensions are counts, so normalise them to int before calling the model.
    output = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        seed=seed,
    )
    # Assumes the pipeline returns array-like images that Image.fromarray
    # accepts (presumably uint8 HxWx3) — confirm against pipeline.py.
    return Image.fromarray(output["images"][0])
def _generate_for_gallery(prompt, steps, guidance_scale, height, width):
    """Run ``generate`` and wrap its single image in a list for ``gr.Gallery``.

    The Gallery output component expects a list of images; returning a bare
    image object from the click handler makes Gradio fail when it iterates
    over the value.
    """
    return [generate(prompt, steps, guidance_scale, height, width)]


# Page layout: prompt box, sampling controls, image-size controls, a button,
# and a gallery that shows the generated image.
with gr.Blocks(title="🔥 CHATS-SDXL Demo") as demo:
    gr.Markdown(
        "## CHATS-SDXL Text-to-Image Demo\n\n"
        "Enter your prompt and click **Generate Image**. All NSFW content will be automatically filtered."
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            placeholder="Enter your description here...",
            lines=2,
        )

    with gr.Row():
        steps_slider = gr.Slider(
            minimum=1, maximum=100, value=50, step=1,
            label="Inference Steps"
        )
        scale_slider = gr.Slider(
            minimum=1.0, maximum=14.0, value=5.0, step=0.1,
            label="Guidance Scale"
        )

    with gr.Row():
        # step=64 keeps both dimensions on 64-pixel multiples.
        height_slider = gr.Slider(
            minimum=64, maximum=2048, value=1024, step=64,
            label="Image Height"
        )
        width_slider = gr.Slider(
            minimum=64, maximum=2048, value=1024, step=64,
            label="Image Width"
        )

    generate_button = gr.Button("Generate Image")
    gallery = gr.Gallery(
        label="Generated Images",
        show_label=False,
        columns=2,
        elem_id="gallery"
    )

    # Inputs are passed positionally in this order to the handler.
    generate_button.click(
        fn=_generate_for_gallery,
        inputs=[prompt_input, steps_slider, scale_slider, height_slider, width_slider],
        outputs=[gallery],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported as a module.
    demo.launch()