File size: 2,717 Bytes
7b664dc
 
 
 
 
ffe9258
7b664dc
2123fad
 
 
23c50d2
 
78e6f58
 
 
 
2123fad
 
f9b26a5
e88e5da
7b664dc
9973325
23c50d2
150be19
 
9973325
 
f3a7387
 
78e6f58
f3a7387
 
 
7b664dc
 
27b9508
7b664dc
 
 
 
27b9508
6d847d3
7b664dc
27b9508
 
7b664dc
78e6f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23c50d2
 
 
78e6f58
23c50d2
 
 
78e6f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b664dc
 
9008411
ffe9258
 
9008411
78e6f58
 
 
ffe9258
d47b997
7b664dc
78e6f58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python
# coding: utf-8

import os
import openai
import gradio as gr

import torch
from diffusers import StableDiffusionPipeline
from torch import autocast

from contextlib import nullcontext
#from PIL import Image
#from torchvision import transforms

#from diffusers import StableDiffusionImageVariationPipeline


# Credentials come from environment variables; os.getenv yields None when unset.
openai.api_key = os.getenv('openaikey')
authtoken = os.getenv('authtoken')

# Choose execution settings once. On a GPU: fp16 weights under torch autocast.
# On CPU: fp32 weights and a no-op context manager in place of autocast.
if torch.cuda.is_available():
    device, context, dtype = "cuda", autocast, torch.float16
else:
    device, context, dtype = "cpu", nullcontext, torch.float32

# Load the Stable Diffusion pipeline and move it to the selected device
# (the pipeline's .to() returns the pipeline itself).
pipe = StableDiffusionPipeline.from_pretrained(
    "stale2000/sd-dnditem",
    torch_dtype=dtype,
    use_auth_token=authtoken,
).to(device)

def predict(input, manual_query_repacement, history=None, *,
            image_prompt="Yoda", scale=10, n_samples=4,
            num_inference_steps=5, unsafe_placeholder="unsafe.png"):
    """Run one chat turn through GPT-3 and generate a gallery of images.

    Args:
        input: Text from the main textbox; used as the completion prompt.
        manual_query_repacement: If non-empty, replaces ``input`` as the
            completion prompt.
        history: Gradio session state, a list of ``(prompt, response)``
            tuples. It may be ``None`` on the first call, in which case a
            fresh list is started.
        image_prompt: Prompt sent to the Stable Diffusion pipeline.
            NOTE(review): this is a fixed value unrelated to the chat text;
            presumably it was meant to be built from the completion.
            Confirm before wiring it up.
        scale: Classifier-free guidance scale passed to the pipeline.
        n_samples: Number of images to generate (the prompt is repeated
            this many times).
        num_inference_steps: Number of denoising steps passed to the pipeline.
        unsafe_placeholder: Image file path shown in place of any image the
            pipeline flags in ``nsfw_content_detected``.

    Returns:
        ``(history, history, images)``. The first is for the chatbot output
        and the second for the session state. ``images`` is a list of PIL
        images and/or the placeholder path, for the gallery.
    """
    # Never mutate a shared default or the caller's list. Gradio's "state"
    # starts out as None, so an empty list is created in that case.
    history = list(history) if history else []

    # A non-empty manual query replaces the textbox text as the GPT-3 prompt.
    if manual_query_repacement:
        input = manual_query_repacement

    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=input,
        temperature=0.9,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0.6,
    )
    # Drop leading/trailing whitespace (e.g. newlines) before displaying.
    response_text = response["choices"][0]["text"].strip()
    history.append((input, response_text))

    # Image generation. `context` is autocast on GPU and a no-op otherwise.
    with context("cuda"):
        output = pipe(
            n_samples * [image_prompt],
            guidance_scale=scale,
            num_inference_steps=num_inference_steps,
        )

    # Debug dump of the raw generations into the working directory;
    # the files are overwritten on every call.
    for idx, im in enumerate(output.images):
        im.save(f"{idx:06}.png")

    # nsfw_content_detected is None when the safety checker is disabled;
    # treat that as "nothing flagged".
    flags = output.nsfw_content_detected or [False] * len(output.images)
    images = [
        unsafe_placeholder if flagged else image
        for image, flagged in zip(output.images, flags)
    ]

    return history, history, images



# Input widgets: the GPT-3 prompt textbox and an optional override query.
inputText = gr.Textbox(value="tmp")
manual_query = gr.Textbox(
    placeholder="Input any query here, to replace the image generation query builder entirely."
)

# Output gallery for the generated images, laid out two per row.
output_img = gr.Gallery(label="Generated image")
output_img.style(grid=2)

# predict(input, manual_query, state) -> (chatbot, state, gallery)
demo = gr.Interface(
    fn=predict,
    inputs=[inputText, manual_query, "state"],
    outputs=["chatbot", "state", output_img],
)
demo.launch()