Spaces:
Runtime error
Runtime error
File size: 2,440 Bytes
7b664dc ffe9258 7b664dc 2123fad 78e6f58 2123fad f9b26a5 e88e5da 7b664dc 9973325 e88e5da 9973325 f3a7387 78e6f58 f3a7387 7b664dc 27b9508 7b664dc 27b9508 6d847d3 7b664dc 27b9508 7b664dc 78e6f58 7b664dc 9008411 ffe9258 9008411 78e6f58 ffe9258 d47b997 7b664dc 78e6f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
#!/usr/bin/env python
# coding: utf-8
import os
import openai
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
#from PIL import Image
#from torchvision import transforms
#from diffusers import StableDiffusionImageVariationPipeline
# Credentials are read from environment variables; os.getenv returns None
# when a variable is not set.
openai.api_key = os.getenv('openaikey')
authtoken = os.getenv('authtoken')

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on CUDA: many CPU kernels are not
# implemented for float16, so loading the weights as float16 and then moving
# the pipeline to the CPU fails at inference time.
pipe = StableDiffusionPipeline.from_pretrained(
    "stale2000/sd-dnditem",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_auth_token=authtoken,
)
pipe = pipe.to(device)
def predict(input, manual_query_repacement, history=None):
    """Complete a text query with GPT-3 and generate a batch of images.

    Args:
        input: Text used as the completion prompt.
        manual_query_repacement: When non-empty, used as the completion
            prompt instead of ``input``.
        history: List of ``(prompt, response_text)`` tuples accumulated over
            previous calls. ``None`` starts a fresh list.

    Returns:
        A ``(history, history, images)`` tuple. ``history`` is returned twice
        (the caller wires it to both a chatbot output and a state output).
        ``images`` holds one entry per generated sample: the generated image,
        or the path ``"unsafe.png"`` when the sample was flagged as NSFW.
    """
    # A fresh list per call; a mutable default would be shared between calls.
    history = [] if history is None else history

    # gpt3
    if manual_query_repacement:
        input = manual_query_repacement
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=input,
        temperature=0.9,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0.6,
    )
    response_text = response["choices"][0]["text"]
    history.append((input, response_text))

    # img generation
    # NOTE(review): the image prompt is a hard-coded placeholder and does not
    # depend on the query or on the GPT-3 response -- confirm what it should
    # be built from.
    prompt = "Yoda"
    scale = 10
    n_samples = 4

    # Mixed precision only makes sense on the GPU; disable it explicitly when
    # the pipeline is running on the CPU.
    with autocast("cuda", enabled=(device == "cuda")):
        result = pipe(n_samples * [prompt], guidance_scale=scale)

    generated = result.images
    # nsfw_content_detected is None when the pipeline has no safety checker.
    nsfw_flags = result.nsfw_content_detected or [False] * len(generated)

    images = []
    for idx, (image, flagged) in enumerate(zip(generated, nsfw_flags)):
        # Keep a copy of every sample on disk (overwritten on each call).
        image.save(f"{idx:06}.png")
        # Flagged samples are replaced by the placeholder picture; the
        # gallery accepts a file path as well as an image object.
        images.append("unsafe.png" if flagged else image)
    return history, history, images
# --- Gradio UI wiring ---

# Main query textbox, pre-filled with the value "tmp".
inputText = gr.Textbox(value="tmp")

# Second textbox; its text is passed to predict as manual_query_repacement.
manual_query = gr.Textbox(
    placeholder="Input any query here, to replace the image generation query builder entirely.",
)

# Gallery that displays the images returned by predict.
output_img = gr.Gallery(label="Generated image")
output_img.style(grid=2)

# Inputs map positionally onto predict's parameters and outputs onto its
# returned tuple; 'state' carries the history list from one call to the next.
demo = gr.Interface(
    fn=predict,
    inputs=[inputText, manual_query, 'state'],
    outputs=["chatbot", 'state', output_img],
)
demo.launch()
|