awacke1 commited on
Commit
2f209b3
·
verified ·
1 Parent(s): 88b2855

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import os
4
+ import random
5
+ import uuid
6
+ import base64
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ import streamlit as st
11
+
12
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
13
+
14
+ DESCRIPTION = """# DALL•E 3 XL v2 High Fi"""
15
+
16
+ def create_download_link(filename):
17
+ with open(filename, "rb") as file:
18
+ encoded_string = base64.b64encode(file.read()).decode('utf-8')
19
+ download_link = f'<a href="data:image/png;base64,{encoded_string}" download="{filename}">Download Image</a>'
20
+ return download_link
21
+
22
+ def save_image(img, prompt):
23
+ unique_name = str(uuid.uuid4()) + ".png"
24
+ img.save(unique_name)
25
+
26
+ # save with prompt to save prompt as image file name
27
+ filename = f"{prompt}.png"
28
+ img.save(filename)
29
+ return filename
30
+
31
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
32
+ if randomize_seed:
33
+ seed = random.randint(0, MAX_SEED)
34
+ return seed
35
+
36
+ MAX_SEED = np.iinfo(np.int32).max
37
+
38
+ if not torch.cuda.is_available():
39
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
40
+
41
+ MAX_SEED = np.iinfo(np.int32).max
42
+
43
+ USE_TORCH_COMPILE = 0
44
+ ENABLE_CPU_OFFLOAD = 0
45
+
46
+
47
+ if torch.cuda.is_available():
48
+ pipe = StableDiffusionXLPipeline.from_pretrained(
49
+ "fluently/Fluently-XL-v4",
50
+ torch_dtype=torch.float16,
51
+ use_safetensors=True,
52
+ )
53
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
54
+
55
+
56
+ pipe.load_lora_weights("ehristoforu/dalle-3-xl-v2", weight_name="dalle-3-xl-lora-v2.safetensors", adapter_name="dalle")
57
+ pipe.set_adapters("dalle")
58
+
59
+ pipe.to("cuda")
60
+
61
+
62
+ def generate(
63
+ prompt: str,
64
+ negative_prompt: str = "",
65
+ use_negative_prompt: bool = False,
66
+ seed: int = 0,
67
+ width: int = 1024,
68
+ height: int = 1024,
69
+ guidance_scale: float = 3,
70
+ randomize_seed: bool = False,
71
+ ):
72
+
73
+
74
+ seed = int(randomize_seed_fn(seed, randomize_seed))
75
+
76
+ if not use_negative_prompt:
77
+ negative_prompt = "" # type: ignore
78
+
79
+ images = pipe(
80
+ prompt=prompt,
81
+ negative_prompt=negative_prompt,
82
+ width=width,
83
+ height=height,
84
+ guidance_scale=guidance_scale,
85
+ num_inference_steps=20,
86
+ num_images_per_prompt=1,
87
+ cross_attention_kwargs={"scale": 0.65},
88
+ output_type="pil",
89
+ ).images
90
+ image_paths = [save_image(img, prompt) for img in images]
91
+
92
+ download_links = [create_download_link(path) for path in image_paths]
93
+
94
+ print(image_paths)
95
+ return image_paths, seed, download_links
96
+
97
+ examples = [
98
+ "a modern hospital room with advanced medical equipment and a patient resting comfortably",
99
+ "a team of surgeons performing a delicate operation using state-of-the-art surgical robots",
100
+ "a elderly woman smiling while a nurse checks her vital signs using a holographic display",
101
+ "a child receiving a painless vaccination from a friendly robot nurse in a colorful pediatric clinic",
102
+ "a group of researchers working in a high-tech laboratory, developing new treatments for rare diseases",
103
+ "a telemedicine consultation between a doctor and a patient, using virtual reality technology for a immersive experience"
104
+ ]
105
+
106
+ st.set_page_config(page_title="DALL•E 3 XL v2 High Fi", layout="centered")
107
+ st.markdown(DESCRIPTION)
108
+
109
+ with st.form(key="generation_form"):
110
+ prompt = st.text_input("Prompt", max_chars=None, placeholder="Enter your prompt")
111
+ use_negative_prompt = st.checkbox("Use negative prompt", value=True)
112
+ if use_negative_prompt:
113
+ negative_prompt = st.text_area(
114
+ "Negative prompt",
115
+ value="""(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, (NSFW:1.25)""",
116
+ placeholder="Enter a negative prompt",
117
+ height=100,
118
+ )
119
+ else:
120
+ negative_prompt = ""
121
+
122
+ col1, col2 = st.columns(2)
123
+ with col1:
124
+ width = st.slider("Width", min_value=512, max_value=2048, step=8, value=1920)
125
+ with col2:
126
+ height = st.slider("Height", min_value=512, max_value=2048, step=8, value=1080)
127
+
128
+ col3, col4 = st.columns(2)
129
+ with col3:
130
+ guidance_scale = st.slider("Guidance Scale", min_value=0.1, max_value=20.0, step=0.1, value=20.0)
131
+ with col4:
132
+ randomize_seed = st.checkbox("Randomize seed", value=True)
133
+ if not randomize_seed:
134
+ seed = st.slider("Seed", min_value=0, max_value=MAX_SEED, step=1, value=0)
135
+ else:
136
+ seed = 0
137
+
138
+ run_button = st.form_submit_button("Run")
139
+
140
+ if run_button:
141
+ image_paths, seed, download_links = generate(
142
+ prompt=prompt,
143
+ negative_prompt=negative_prompt,
144
+ use_negative_prompt=use_negative_prompt,
145
+ seed=seed,
146
+ width=width,
147
+ height=height,
148
+ guidance_scale=guidance_scale,
149
+ randomize_seed=randomize_seed,
150
+ )
151
+
152
+ for image_path in image_paths:
153
+ st.image(image_path, caption=prompt)
154
+
155
+ for download_link in download_links:
156
+ st.markdown(download_link, unsafe_allow_html=True)
157
+
158
+ st.text(f"Seed: {seed}")
159
+
160
+ st.subheader("Examples")
161
+ for example in examples:
162
+ st.button(example, key=example, on_click=prompt.setter(example))