|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
|
|
|
# Hugging Face Hub repository id of the diffusion checkpoint this app serves.
model_name = "UnfilteredAI/NSFW-gen-v2"

# Load the pipeline once at import time in half precision and place it on the
# GPU; the request handlers below reuse this single module-level object.
pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
pipe.to("cuda")
|
|
|
|
|
|
|
|
def _pad_token_ids(ids, length, pad_id):
    """Right-pad a 2-D token-id tensor with ``pad_id`` so it has ``length`` columns."""
    missing = length - ids.shape[-1]
    if missing <= 0:
        return ids
    return torch.nn.functional.pad(ids, (0, missing), value=pad_id)


def build_embeddings(enhanced_prompt, negative_prompt=None):
    """Encode a (possibly over-long) prompt pair into text-encoder embeddings.

    The text encoder is fed windows of at most ``pipe.tokenizer.model_max_length``
    token ids, and the per-window hidden states are concatenated along the
    sequence axis. This lets prompts exceed the tokenizer's length limit.

    NOTE(review): assumes a pipeline with a single ``tokenizer``/``text_encoder``
    pair (SD 1.x/2.x style). SDXL pipelines additionally need pooled embeddings
    from a second encoder. Confirm against the loaded checkpoint.

    Args:
        enhanced_prompt: Positive prompt text.
        negative_prompt: Negative prompt text; ``None`` or ``""`` means empty.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)``. Both tensors have the same
        sequence length, as required when passing embeddings to the pipeline
        directly.
    """
    tokenizer = pipe.tokenizer
    window = tokenizer.model_max_length
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Some tokenizers define no pad token; reuse EOS rather than failing.
        pad_id = tokenizer.eos_token_id

    # Tokenize both prompts in full (no truncation) so nothing is cut off.
    input_ids = tokenizer(
        enhanced_prompt, truncation=False, return_tensors="pt"
    ).input_ids
    negative_ids = tokenizer(
        negative_prompt or "", truncation=False, return_tensors="pt"
    ).input_ids

    # Pad the shorter sequence to the longer one. Previously only the negative
    # prompt was padded (to the positive length), so a longer negative prompt
    # produced mismatched embedding shapes.
    seq_len = max(input_ids.shape[-1], negative_ids.shape[-1])
    input_ids = _pad_token_ids(input_ids, seq_len, pad_id).to("cuda")
    negative_ids = _pad_token_ids(negative_ids, seq_len, pad_id).to("cuda")

    concat_embeds = []
    neg_embeds = []
    # Inference only: skip building an autograd graph through the encoder.
    with torch.no_grad():
        for start in range(0, seq_len, window):
            stop = start + window
            concat_embeds.append(pipe.text_encoder(input_ids[:, start:stop])[0])
            neg_embeds.append(pipe.text_encoder(negative_ids[:, start:stop])[0])

    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    return prompt_embeds, negative_prompt_embeds
|
|
|
|
|
|
|
|
@spaces.GPU
def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, num_samples):
    """Generate images for a prompt using precomputed text embeddings.

    Args:
        prompt: Positive prompt text (encoded in windows by ``build_embeddings``).
        negative_prompt: Negative prompt text; ``None``/empty means none.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        width: Output width in pixels.
        height: Output height in pixels.
        num_samples: Number of images to generate for the prompt.

    Returns:
        The ``images`` list from the pipeline output.
    """
    # Bug fix: the original unpacked into ``neg_prompt_embeds`` but then read
    # ``negative_prompt_embeds``, raising NameError on every call.
    prompt_embeds, negative_prompt_embeds = build_embeddings(prompt, negative_prompt)

    # gr.Number inputs arrive as floats (e.g. 7.0); the pipeline expects
    # integers for step count, image size and images-per-prompt.
    return pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        width=int(width),
        height=int(height),
        num_images_per_prompt=int(num_samples),
    ).images
|
|
|
|
|
|
|
|
# Web UI: each input component maps positionally onto a parameter of
# `generate`, and the returned image list is displayed in a gallery.
gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Prompt"),
        gr.Text("", label="Negative Prompt"),
        # precision=0 makes Gradio round and deliver an int. Without it these
        # fields arrive as floats, but the pipeline expects integers for
        # steps, pixel sizes and image count.
        gr.Number(7, label="Number inference steps", precision=0),
        gr.Number(3, label="Guidance scale"),
        gr.Number(512, label="Width", precision=0),
        gr.Number(512, label="Height", precision=0),
        gr.Number(1, label="# images", precision=0),
    ],
    outputs=gr.Gallery(),
).launch()