File size: 1,803 Bytes
210ed13
62c5b0c
 
210ed13
 
62c5b0c
210ed13
62c5b0c
 
 
 
 
 
210ed13
 
 
 
 
62c5b0c
210ed13
 
62c5b0c
210ed13
 
62c5b0c
 
 
 
 
 
 
210ed13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62c5b0c
 
210ed13
 
 
 
 
 
 
 
 
 
62c5b0c
210ed13
 
62c5b0c
210ed13
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import spaces
from tempfile import NamedTemporaryFile
import numpy as np
import random
from diffusers import StableDiffusionPipeline as DiffusionPipeline
import torch
from pathos.multiprocessing import ProcessingPool as ProcessPoolExecutor

# NOTE(review): this 100-worker process pool is not used anywhere in this file,
# and ``__enter__`` is called without a matching ``__exit__``, so the workers
# stay alive for the whole process. Kept only in case other code imports
# ``pool``; confirm and remove if nothing does.
pool = ProcessPoolExecutor(100)
pool.__enter__()

model_id = "runwayml/stable-diffusion-v1-5"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Request the float16 / "fp16" weight variant only when running on CUDA;
# on CPU the checkpoint is loaded with from_pretrained's default dtype.
if device == "cuda":
    load_kwargs = {"torch_dtype": torch.float16, "variant": "fp16"}
else:
    load_kwargs = {}

# Load the text-to-image pipeline once at import time and move it to the
# selected device; every request reuses this single instance.
pipe = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True, **load_kwargs)
pipe = pipe.to(device)

# ``spaces.GPU``'s first positional parameter is the function being wrapped;
# the per-call GPU time budget (seconds) must be passed as ``duration=``.
@spaces.GPU(duration=10)
def infer(prompt):
    """Generate an image for *prompt* and return the path of a PNG holding it.

    Runs the module-level ``pipe`` on the prompt, takes the first generated
    image, and writes it to a fresh temporary ``.png`` file.

    Args:
        prompt: Text prompt forwarded unchanged to ``pipe``.

    Returns:
        Filesystem path of the written PNG file (the result ``gr.Image`` is
        configured with ``type='filepath'``).
    """
    image = pipe(prompt).images[0]
    # delete=False: the file must outlive this handle so the caller can read
    # it by path after we return. Nothing in this file removes it later.
    with NamedTemporaryFile("wb", suffix=".png", delete=False) as file:
        image.save(file, format="PNG")
        return file.name

# Stylesheet for the page: caps the main column's width and centres it.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Human-readable label for the hardware in use, shown in the page header.
power_device = "GPU" if torch.cuda.is_available() else "CPU"

# Page layout: a width-capped column with a header, a one-line prompt box
# next to a Run button, and the generated image underneath.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
            # Image Generator
            Currently running on {power_device}.
        """)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(
            label="Result",
            show_label=False,
            type='filepath',  # infer() returns a path on disk
        )

    # Clicking Run passes the prompt text to infer() and displays its output.
    run_button.click(fn=infer, inputs=[prompt], outputs=[result])

# Enable request queuing (queue() configures demo in place), then serve.
demo.queue()
demo.launch()