lunarflu HF Staff commited on
Commit
7984c7c
Β·
1 Parent(s): 78a656e
Files changed (1) hide show
  1. app.py +292 -0
app.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import glob
3
+ import os
4
+ import pathlib
5
+ import random
6
+ import os
7
+ import random
8
+ import threading
9
+
10
+ import discord
11
+ import gradio as gr
12
+
13
+ import discord
14
+ from gradio_client import Client
15
+ from PIL import Image
16
+
17
+ from discord.ui import Button, View
18
+
19
+
20
+ #---------------------------------------------------------------------------------------------------------------------
21
+ event = Event()
22
+ HF_TOKEN = os.getenv("HF_TOKEN")
23
+ deepfloydif_client = Client("huggingface-projects/IF", HF_TOKEN)
24
+ #---------------------------------------------------------------------------------------------------------------------
25
+ DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")
26
+ intents = discord.Intents.default()
27
+ intents.message_content = True
28
+ bot = commands.Bot(command_prefix="/", intents=intents)
29
+ #---------------------------------------------------------------------------------------------------------------------
30
+
31
+ def deepfloydif_generate64_inference(prompt):
32
+ """Generates four images based on a prompt"""
33
+ negative_prompt = ""
34
+ seed = random.randint(0, 1000)
35
+ number_of_images = 4
36
+ guidance_scale = 7
37
+ custom_timesteps_1 = "smart50"
38
+ number_of_inference_steps = 50
39
+ (
40
+ stage_1_images,
41
+ stage_1_param_path,
42
+ path_for_upscale256_upscaling,
43
+ ) = deepfloydif_client.predict(
44
+ prompt,
45
+ negative_prompt,
46
+ seed,
47
+ number_of_images,
48
+ guidance_scale,
49
+ custom_timesteps_1,
50
+ number_of_inference_steps,
51
+ api_name="/generate64",
52
+ )
53
+ return [stage_1_images, stage_1_param_path, path_for_upscale256_upscaling]
54
+
55
+
56
+ def deepfloydif_upscale256_inference(index, path_for_upscale256_upscaling):
57
+ """Upscales one of the images from deepfloydif_generate64_inference based on the chosen index"""
58
+ selected_index_for_upscale256 = index
59
+ seed_2 = 0
60
+ guidance_scale_2 = 4
61
+ custom_timesteps_2 = "smart50"
62
+ number_of_inference_steps_2 = 50
63
+ result_path = deepfloydif_client.predict(
64
+ path_for_upscale256_upscaling,
65
+ selected_index_for_upscale256,
66
+ seed_2,
67
+ guidance_scale_2,
68
+ custom_timesteps_2,
69
+ number_of_inference_steps_2,
70
+ api_name="/upscale256",
71
+ )
72
+ return result_path
73
+
74
+
75
+ def deepfloydif_upscale1024_inference(index, path_for_upscale256_upscaling, prompt):
76
+ """Upscales to stage 2, then stage 3"""
77
+ selected_index_for_upscale256 = index
78
+ seed_2 = 0 # default seed for stage 2 256 upscaling
79
+ guidance_scale_2 = 4 # default for stage 2
80
+ custom_timesteps_2 = "smart50" # default for stage 2
81
+ number_of_inference_steps_2 = 50 # default for stage 2
82
+ negative_prompt = "" # empty (not used, could add in the future)
83
+
84
+ seed_3 = 0 # default for stage 3 1024 upscaling
85
+ guidance_scale_3 = 9 # default for stage 3
86
+ number_of_inference_steps_3 = 40 # default for stage 3
87
+
88
+ result_path = deepfloydif_client.predict(
89
+ path_for_upscale256_upscaling,
90
+ selected_index_for_upscale256,
91
+ seed_2,
92
+ guidance_scale_2,
93
+ custom_timesteps_2,
94
+ number_of_inference_steps_2,
95
+ prompt,
96
+ negative_prompt,
97
+ seed_3,
98
+ guidance_scale_3,
99
+ number_of_inference_steps_3,
100
+ api_name="/upscale1024",
101
+ )
102
+ return result_path
103
+
104
+
105
+ def load_image(png_files, stage_1_images):
106
+ """Opens images as variables so we can combine them later"""
107
+ results = []
108
+ for file in png_files:
109
+ png_path = os.path.join(stage_1_images, file)
110
+ results.append(Image.open(png_path))
111
+ return results
112
+
113
+
114
+ def combine_images(png_files, stage_1_images, partial_path):
115
+ if os.environ.get("TEST_ENV") == "True":
116
+ print("Combining images for deepfloydif_generate64")
117
+ images = load_image(png_files, stage_1_images)
118
+ combined_image = Image.new("RGB", (images[0].width * 2, images[0].height * 2))
119
+ combined_image.paste(images[0], (0, 0))
120
+ combined_image.paste(images[1], (images[0].width, 0))
121
+ combined_image.paste(images[2], (0, images[0].height))
122
+ combined_image.paste(images[3], (images[0].width, images[0].height))
123
+ combined_image_path = os.path.join(stage_1_images, f"{partial_path}.png")
124
+ combined_image.save(combined_image_path)
125
+ return combined_image_path
126
+
127
+
128
+ @bot.hybrid_command(
129
+ name="deepfloydif",
130
+ description="Enter a prompt to generate an image! Can generate realistic text, too!",
131
+ )
132
+ async def deepfloydif(ctx, prompt: str):
133
+ """DeepfloydIF stage 1 generation"""
134
+ try:
135
+ await deepfloydif_generate64(ctx, prompt, client)
136
+ except Exception as e:
137
+ print(f"Error: {e}")
138
+
139
+
140
+ async def deepfloydif_generate64(ctx, prompt, client):
141
+ """DeepfloydIF command (generate images with realistic text using slash commands)"""
142
+ try:
143
+ if os.environ.get("TEST_ENV") == "True":
144
+ print("Safety checks passed for deepfloydif_generate64")
145
+ channel = client.get_channel(DEEPFLOYDIF_CHANNEL_ID)
146
+ # interaction.response message can't be used to create a thread, so we create another message
147
+ message = await ctx.send(f"**{prompt}** - {ctx.author.mention} <a:loading:1114111677990981692>")
148
+
149
+ loop = asyncio.get_running_loop()
150
+ result = await loop.run_in_executor(None, deepfloydif_generate64_inference, prompt)
151
+ stage_1_images = result[0]
152
+ path_for_upscale256_upscaling = result[2]
153
+
154
+ partial_path = pathlib.Path(path_for_upscale256_upscaling).name
155
+ png_files = list(glob.glob(f"{stage_1_images}/**/*.png"))
156
+
157
+ if png_files:
158
+ await message.delete()
159
+ combined_image_path = combine_images(png_files, stage_1_images, partial_path)
160
+ if os.environ.get("TEST_ENV") == "True":
161
+ print("Images combined for deepfloydif_generate64")
162
+
163
+ with Image.open(combined_image_path) as img:
164
+ width, height = img.size
165
+ new_width = width * 3
166
+ new_height = height * 3
167
+ resized_img = img.resize((new_width, new_height))
168
+ x2_combined_image_path = combined_image_path
169
+ resized_img.save(x2_combined_image_path)
170
+
171
+ # making image bigger, more readable
172
+ with open(x2_combined_image_path, "rb") as f: # was combined_image_path
173
+ button1 = Button(custom_id="0", emoji="β†–")
174
+ button2 = Button(custom_id="1", emoji="β†—")
175
+ button3 = Button(custom_id="2", emoji="↙")
176
+ button4 = Button(custom_id="3", emoji="β†˜")
177
+
178
+ async def button_callback(interaction):
179
+ index = int(interaction.data["custom_id"]) # 0,1,2,3
180
+
181
+ await interaction.response.send_message(
182
+ f"{interaction.user.mention} <a:loading:1114111677990981692>", ephemeral=True
183
+ )
184
+ result_path = await deepfloydif_upscale256(index, path_for_upscale256_upscaling)
185
+
186
+ # create and use upscale 1024 button
187
+ with open(result_path, "rb") as f:
188
+ upscale1024 = Button(
189
+ label="High-quality upscale (x4)", custom_id=str(index)
190
+ ) # "0", "1" etc
191
+ upscale1024.callback = upscale1024_callback
192
+ view = View(timeout=None)
193
+ view.add_item(upscale1024)
194
+
195
+ await interaction.delete_original_response()
196
+ await channel.send(
197
+ content=(
198
+ f"{interaction.user.mention} Here is the upscaled image! Click the button"
199
+ " to upscale even more!"
200
+ ),
201
+ file=discord.File(f, f"{prompt}.png"),
202
+ view=view,
203
+ )
204
+
205
+ async def upscale1024_callback(interaction):
206
+ index = int(interaction.data["custom_id"])
207
+
208
+ await interaction.response.send_message(
209
+ f"{interaction.user.mention} <a:loading:1114111677990981692>", ephemeral=True
210
+ )
211
+ result_path = await deepfloydif_upscale1024(index, path_for_upscale256_upscaling, prompt)
212
+
213
+ with open(result_path, "rb") as f:
214
+ await interaction.delete_original_response()
215
+ await channel.send(
216
+ content=f"{interaction.user.mention} Here's your high-quality x16 image!",
217
+ file=discord.File(f, f"{prompt}.png"),
218
+ )
219
+
220
+ button1.callback = button_callback
221
+ button2.callback = button_callback
222
+ button3.callback = button_callback
223
+ button4.callback = button_callback
224
+
225
+ view = View(timeout=None)
226
+ view.add_item(button1)
227
+ view.add_item(button2)
228
+ view.add_item(button3)
229
+ view.add_item(button4)
230
+
231
+ # could store this message as combined_image_dfif in case it's useful for future testing
232
+ await channel.send(
233
+ f"**{prompt}** - {ctx.author.mention} Click a button to upscale! (make larger + enhance"
234
+ " quality)",
235
+ file=discord.File(f, f"{partial_path}.png"),
236
+ view=view,
237
+ )
238
+ else:
239
+ await ctx.send(f"{ctx.author.mention} No PNG files were found, cannot post them!")
240
+
241
+ except Exception as e:
242
+ print(f"Error: {e}")
243
+
244
+
245
+ async def deepfloydif_upscale256(index: int, path_for_upscale256_upscaling):
246
+ """upscaling function for images generated using /deepfloydif"""
247
+ try:
248
+ loop = asyncio.get_running_loop()
249
+ result_path = await loop.run_in_executor(
250
+ None, deepfloydif_upscale256_inference, index, path_for_upscale256_upscaling
251
+ )
252
+ return result_path
253
+
254
+ except Exception as e:
255
+ print(f"Error: {e}")
256
+
257
+
258
+ async def deepfloydif_upscale1024(index: int, path_for_upscale256_upscaling, prompt):
259
+ """upscaling function for images generated using /deepfloydif"""
260
+ try:
261
+ loop = asyncio.get_running_loop()
262
+ result_path = await loop.run_in_executor(
263
+ None, deepfloydif_upscale1024_inference, index, path_for_upscale256_upscaling, prompt
264
+ )
265
+ return result_path
266
+
267
+ except Exception as e:
268
+ print(f"Error: {e}")
269
+
270
+
271
+ #---------------------------------------------------------------------------------------------------------------------
272
+ def run_bot():
273
+ if not DISCORD_TOKEN:
274
+ print("DISCORD_TOKEN NOT SET")
275
+ event.set()
276
+ else:
277
+ bot.run(DISCORD_TOKEN)
278
+
279
+
280
+ threading.Thread(target=run_bot).start()
281
+
282
+ event.wait()
283
+
284
+ with gr.Blocks() as demo:
285
+ gr.Markdown(
286
+ """
287
+ # Discord bot of https://huggingface.co/spaces/facebook/MusicGen
288
+ https://discord.com/api/oauth2/authorize?client_id=1151888750662664222&permissions=309237696512&scope=bot
289
+ """
290
+ )
291
+
292
+ demo.launch()