MercuryNex committed on
Commit 64ceaa0 · verified · 1 Parent(s): 6f71b05

Create app.py

Files changed (1)
app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ import gradio as gr
+ import torch
+ import os
+ import gc
+ import random
+ from huggingface_hub import snapshot_download
+ from diffusers import StableDiffusionXLPipeline, LCMScheduler
+ from PIL import Image
+
+ # Keep all Hugging Face caches inside the Space's writable home directory.
+ os.environ["XDG_CACHE_HOME"] = "/home/user/.cache"
+ os.environ["TRANSFORMERS_CACHE"] = "/home/user/.cache/huggingface/transformers"
+ os.environ["HF_HOME"] = "/home/user/.cache/huggingface"
+
+ models = [
+     "Niggendar/fastPhotoPony_v80MixB",
+     "Niggendar/realisticPonyPhoto_v10",
+     "Niggendar/realmix_v10",
+     "Niggendar/realmixpony_v01",
+     "Niggendar/realmixpony_v02",
+     "Niggendar/recondiff_v10",
+     "Niggendar/Regro",
+     "Niggendar/relhCheckpoint_v20",
+ ]
+ loras = ["openskyml/lcm-lora-sdxl-turbo"]
+
+ pipe = None
+ cached = {}
+ cached_loras = {}
+
+ def get_lora(lora_id):
+     # Download the LoRA repo once and remember the local weight file path.
+     if lora_id in cached_loras:
+         return cached_loras[lora_id]
+     lora_dir = snapshot_download(repo_id=lora_id, allow_patterns=["*.safetensors", "*.bin"])
+     lora_files = [f for f in os.listdir(lora_dir) if f.endswith((".safetensors", ".bin"))]
+     lora_path = os.path.join(lora_dir, lora_files[0])
+     cached_loras[lora_id] = lora_path
+     return lora_path
+
+ def load_pipe(model_id, lora_id):
+     global pipe
+     if (model_id, lora_id) in cached:
+         pipe = cached[(model_id, lora_id)]
+         return gr.update(value='-')
+     if pipe is not None:
+         # Tear down the previous pipeline so only one model occupies RAM at a time.
+         pipe.to("meta")
+         pipe.unet = None
+         pipe.vae = None
+         pipe.text_encoder = None
+         del pipe
+         gc.collect()
+         cached.clear()
+     pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)
+     pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+     pipe.load_lora_weights(get_lora(lora_id))
+     pipe.to("cpu", dtype=torch.float32)
+     pipe.enable_attention_slicing()
+     cached[(model_id, lora_id)] = pipe
+     return gr.update(value='-')
+
+ def infer(model_id, lora_id, prompt, seed=None, steps=4, guid=0.1):
+     # Runs the globally loaded pipe; model_id/lora_id mirror the dropdown values.
+     if seed is None or seed == "":
+         seed = random.randint(0, 2**32 - 1)
+     # Stream a gray placeholder first, then the finished image.
+     yield Image.new("RGB", (512, 512), color="gray"), gr.update(value='-')
+     image = pipe(
+         prompt,
+         generator=torch.manual_seed(int(seed)),
+         num_inference_steps=steps,
+         guidance_scale=guid,
+         width=384,
+         height=384,
+         added_cond_kwargs={},
+     ).images[0]
+     yield image, gr.update(value='-')
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column(scale=2):
+             text2 = gr.Textbox(label="Time", placeholder="timer", container=False, value='-')
+             mbtn = gr.Button(value="Load Pair")
+             modeldrop = gr.Dropdown(models, label="Model")
+             loradrop = gr.Dropdown(loras, label="LCM LoRA")
+             with gr.Accordion(label="Settings", open=False):
+                 seed = gr.Textbox(label="Seed", visible=False)
+                 steps = gr.Slider(1, 15, value=4, step=1, label="Steps")
+                 guidance = gr.Slider(0.0, 2.0, value=0.1, step=0.1, label="Guidance Scale")
+         with gr.Column(scale=3):
+             text = gr.Textbox(label="Prompt", container=False, placeholder="Prompt", value='')
+             gbtn = gr.Button(value="Generate")
+             imageout = gr.Image()
+     mbtn.click(fn=load_pipe, inputs=[modeldrop, loradrop], outputs=[text2])
+     gbtn.click(fn=infer, inputs=[modeldrop, loradrop, text, seed, steps, guidance], outputs=[imageout, text2])
+
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", server_port=7860)
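For reference, a minimal standalone sketch of the same SDXL + LCM-LoRA flow that "Load Pair" and "Generate" wire together. The prompt, seed, and output path below are placeholders, not part of the commit; the model and LoRA IDs come from the lists in app.py:

import torch
from diffusers import StableDiffusionXLPipeline, LCMScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "Niggendar/realmix_v10",  # any entry from the models list in app.py
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("openskyml/lcm-lora-sdxl-turbo")  # the LCM LoRA from app.py
pipe.to("cpu")
pipe.enable_attention_slicing()

image = pipe(
    "a photo of a mountain lake at dawn",  # placeholder prompt
    generator=torch.manual_seed(42),       # placeholder seed
    num_inference_steps=4,
    guidance_scale=0.1,
    width=384,
    height=384,
).images[0]
image.save("out.png")  # placeholder output path

With the LCM scheduler and LCM LoRA applied, a handful of steps at a near-zero guidance scale is enough, which is what makes the app's CPU-only, 384x384 inference tolerable. To try the app itself, run python app.py and open port 7860.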