#!/usr/bin/env python
# coding: utf-8
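"""Gradio demo: sends the user's prompt to OpenAI's text-davinci-003 and
generates images with the stale2000/sd-dnditem Stable Diffusion checkpoint,
returning the chat history and an image gallery."""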

import os
import openai
import gradio as gr

import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
from PIL import Image  # used for the NSFW placeholder image below


# Secrets are read from the environment: 'openaikey' for the OpenAI API,
# 'authtoken' (a Hugging Face access token) for the model download.
openai.api_key = os.getenv('openaikey')
authtoken = os.getenv('authtoken')

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 weights need a GPU; fall back to full precision on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("stale2000/sd-dnditem", torch_dtype=dtype, use_auth_token=authtoken)
pipe = pipe.to(device)

def predict(input, manual_query_replacement, history=[]):

    # GPT-3 step: a manually supplied query replaces the built prompt entirely.
    if manual_query_replacement != "":
        input = manual_query_replacement

    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=input,
        temperature=0.9,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0.6)

    # Record the prompt/completion pair in the chat history.
    responseText = response["choices"][0]["text"]
    history.append((input, responseText))


    # Image generation (the prompt is currently a hard-coded placeholder).
    prompt = "Yoda"
    scale = 10      # classifier-free guidance scale
    n_samples = 4   # images generated per request

    # The safety checker can produce false positives; it can be disabled
    # at your own risk by uncommenting the block below.
    #disable_safety = False

    #if disable_safety:
    #  def null_safety(images, **kwargs):
    #      return images, False
    #  pipe.safety_checker = null_safety

    # Run the diffusion pipeline once; the output carries both the images
    # and a per-image NSFW flag from the safety checker.
    with autocast(device):
        output = pipe(n_samples * [prompt], guidance_scale=scale)

    for idx, im in enumerate(output.images):
        im.save(f"{idx:06}.png")

    # Swap any flagged image for a placeholder ("unsafe.png" must exist
    # next to this script).
    images = []
    for i, image in enumerate(output.images):
        if output.nsfw_content_detected[i]:
            safe_image = Image.open(r"unsafe.png")
            images.append(safe_image)
        else:
            images.append(image)

    return history, history, images



inputText = gr.Textbox(value="tmp")  # "tmp" is only a placeholder default
manual_query = gr.Textbox(placeholder="Input any query here, to replace the image generation query builder entirely.")

output_img = gr.Gallery(label="Generated image")
output_img.style(grid=2)
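# Note: Gallery.style(grid=...) is the Gradio 3.x layout API; newer Gradio
# releases configure this via Gallery(columns=...) instead.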

gr.Interface(fn=predict,
             inputs=[inputText, manual_query, 'state'],
             outputs=["chatbot", 'state', output_img]).launch()
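
# A sketch of the runtime assumptions (package names only; versions not
# pinned here): this script targets the legacy `openai.Completion` API
# (openai<1.0) and a diffusers release whose pipeline output exposes
# `images` and `nsfw_content_detected`.
#   pip install "openai<1" diffusers transformers torch gradio pillow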