ginipick commited on
Commit
373a32f
ยท
verified ยท
1 Parent(s): d85f4d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -56
app.py CHANGED
@@ -4,13 +4,38 @@ import os
4
  import logging
5
  import json
6
  from datetime import datetime
 
 
 
7
  import shutil
 
 
 
 
 
 
8
 
9
  # ๋กœ๊น… ์„ค์ •
10
  logging.basicConfig(level=logging.INFO)
11
 
12
- # API ํด๋ผ์ด์–ธํŠธ ์„ค์ •
13
- api_client = Client("http://211.233.58.202:7960/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # ๊ฐค๋Ÿฌ๋ฆฌ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
16
  GALLERY_DIR = "gallery"
@@ -19,16 +44,16 @@ GALLERY_JSON = "gallery.json"
19
  # ๊ฐค๋Ÿฌ๋ฆฌ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
20
  os.makedirs(GALLERY_DIR, exist_ok=True)
21
 
22
- def save_to_gallery(image_path, prompt):
23
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
24
- new_image_path = os.path.join(GALLERY_DIR, f"{timestamp}.png")
25
 
26
- # ์ด๋ฏธ์ง€ ํŒŒ์ผ ๋ณต์‚ฌ
27
- shutil.copy2(image_path, new_image_path)
28
 
29
  # ๊ฐค๋Ÿฌ๋ฆฌ ์ •๋ณด ์ €์žฅ
30
  gallery_info = {
31
- "image": new_image_path,
32
  "prompt": prompt,
33
  "timestamp": timestamp
34
  }
@@ -44,44 +69,55 @@ def save_to_gallery(image_path, prompt):
44
  with open(GALLERY_JSON, "w") as f:
45
  json.dump(gallery, f, indent=2)
46
 
47
- return new_image_path
48
 
49
  def load_gallery():
50
  if os.path.exists(GALLERY_JSON):
51
  with open(GALLERY_JSON, "r") as f:
52
  gallery = json.load(f)
53
- return [(item["image"], item["prompt"]) for item in reversed(gallery)]
54
  return []
55
 
56
- def respond(message, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
57
- logging.info(f"Received message: {message}, seed: {seed}, randomize_seed: {randomize_seed}, "
58
- f"width: {width}, height: {height}, guidance_scale: {guidance_scale}, "
59
- f"num_inference_steps: {num_inference_steps}")
 
 
 
 
 
 
60
 
61
  try:
62
- # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์š”์ฒญ
 
 
 
 
 
 
63
  result = api_client.predict(
64
- prompt=message,
 
 
 
 
 
65
  seed=seed,
66
- randomize_seed=randomize_seed,
67
- width=width,
68
- height=height,
69
- guidance_scale=guidance_scale,
70
- num_inference_steps=num_inference_steps,
71
- api_name="/infer_t2i"
72
  )
73
  logging.info("API response received: %s", result)
74
-
75
- # ๊ฒฐ๊ณผ ํ™•์ธ ๋ฐ ์ฒ˜๋ฆฌ
76
- if isinstance(result, tuple) and len(result) >= 1:
77
- image_path = result[0]
78
- saved_image_path = save_to_gallery(image_path, message)
79
- return saved_image_path
80
  else:
81
- raise ValueError("Unexpected API response format")
82
  except Exception as e:
83
- logging.error("Error during API request: %s", str(e))
84
- return "Failed to generate image due to an error."
85
 
86
  css = """
87
  footer {
@@ -119,64 +155,65 @@ examples = [
119
  ["A fantasy map of a fictional world, with detailed terrain and cities.", "q19.webp"]
120
  ]
121
 
122
- def use_prompt(prompt):
123
- return prompt
 
 
 
 
124
 
125
  with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
126
  with gr.Tab("Generate"):
127
  with gr.Row():
128
- input_text = gr.Textbox(label="Enter your prompt for image generation")
129
- output_image = gr.Image(label="Generated Image")
 
130
 
131
  with gr.Row():
 
 
 
 
132
  seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="Seed", value=123)
133
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
134
-
135
- with gr.Row():
136
- width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=1024)
137
- height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=576)
138
-
139
- with gr.Row():
140
- guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.1, label="Guidance Scale", value=5)
141
- num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Number of Inference Steps", value=28)
142
 
143
  with gr.Row():
144
  for prompt, image_file in examples:
145
  with gr.Column():
146
  gr.Image(image_file, label=prompt[:50] + "...")
147
- gr.Button("Use this prompt").click(
148
- fn=use_prompt,
149
  inputs=[],
150
- outputs=input_text,
151
  api_name=False
152
  ).then(
153
- lambda x=prompt: x,
154
  inputs=[],
155
- outputs=input_text
156
  )
157
-
158
  with gr.Tab("Gallery"):
159
  gallery = gr.Gallery(
160
- label="Generated Images",
161
  show_label=False,
162
  elem_id="gallery",
163
  columns=[5],
164
  rows=[3],
165
  object_fit="contain",
166
- height="auto"
 
167
  )
 
168
  refresh_btn = gr.Button("Refresh Gallery")
169
 
170
- def update_gallery():
171
- return load_gallery()
172
-
173
  refresh_btn.click(fn=update_gallery, inputs=None, outputs=gallery)
174
  demo.load(fn=update_gallery, inputs=None, outputs=gallery)
 
175
 
176
  input_text.submit(
177
  fn=respond,
178
- inputs=[input_text, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
179
- outputs=output_image
180
  ).then(
181
  fn=update_gallery,
182
  inputs=None,
 
4
  import logging
5
  import json
6
  from datetime import datetime
7
+ import tempfile
8
+ import numpy as np
9
+ from PIL import Image
10
  import shutil
11
+ import httpx
12
+ import time
13
+ import base64
14
+ from gradio_client import Client, handle_file
15
+ import cv2
16
+ from moviepy.editor import VideoFileClip
17
 
18
  # ๋กœ๊น… ์„ค์ •
19
  logging.basicConfig(level=logging.INFO)
20
 
21
+ # ํƒ€์ž„์•„์›ƒ ์„ค์ •์„ 30์ดˆ๋กœ ๋Š˜๋ฆผ
22
+ httpx_client = httpx.Client(timeout=30.0)
23
+
24
+ max_retries = 3
25
+ retry_delay = 5 # 5์ดˆ ๋Œ€๊ธฐ
26
+
27
+ for attempt in range(max_retries):
28
+ try:
29
+ api_client = Client("http://211.233.58.202:7960/")
30
+ api_client.httpx_client = httpx_client # httpx ํด๋ผ์ด์–ธํŠธ ์„ค์ •
31
+ break # ์„ฑ๊ณตํ•˜๋ฉด ๋ฃจํ”„ ์ข…๋ฃŒ
32
+ except httpx.ReadTimeout:
33
+ if attempt < max_retries - 1: # ๋งˆ์ง€๋ง‰ ์‹œ๋„๊ฐ€ ์•„๋‹ˆ๋ฉด
34
+ print(f"Connection timed out. Retrying in {retry_delay} seconds...")
35
+ time.sleep(retry_delay)
36
+ else:
37
+ print("Failed to connect after multiple attempts.")
38
+ raise # ๋ชจ๋“  ์‹œ๋„ ์‹คํŒจ ์‹œ ์˜ˆ์™ธ ๋ฐœ์ƒ
39
 
40
  # ๊ฐค๋Ÿฌ๋ฆฌ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
41
  GALLERY_DIR = "gallery"
 
44
  # ๊ฐค๋Ÿฌ๋ฆฌ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
45
  os.makedirs(GALLERY_DIR, exist_ok=True)
46
 
47
+ def save_to_gallery(video_path, prompt):
48
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
49
+ new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}.mp4")
50
 
51
+ # ๋น„๋””์˜ค ํŒŒ์ผ ๋ณต์‚ฌ
52
+ shutil.copy2(video_path, new_video_path)
53
 
54
  # ๊ฐค๋Ÿฌ๋ฆฌ ์ •๋ณด ์ €์žฅ
55
  gallery_info = {
56
+ "video": new_video_path,
57
  "prompt": prompt,
58
  "timestamp": timestamp
59
  }
 
69
  with open(GALLERY_JSON, "w") as f:
70
  json.dump(gallery, f, indent=2)
71
 
72
+ return new_video_path
73
 
74
  def load_gallery():
75
  if os.path.exists(GALLERY_JSON):
76
  with open(GALLERY_JSON, "r") as f:
77
  gallery = json.load(f)
78
+ return [{"image": item["video"], "caption": item["prompt"]} for item in reversed(gallery)]
79
  return []
80
 
81
+ def update_gallery():
82
+ gallery_items = load_gallery()
83
+ return [
84
+ {"video": item['image'], "caption": item['caption']}
85
+ for item in gallery_items
86
+ ]
87
+
88
+ def respond(image, prompt, steps, cfg_scale, eta, fs, seed, video_length):
89
+ logging.info(f"Received prompt: {prompt}, steps: {steps}, cfg_scale: {cfg_scale}, "
90
+ f"eta: {eta}, fs: {fs}, seed: {seed}, video_length: {video_length}")
91
 
92
  try:
93
+ # ์ด๋ฏธ์ง€ ํŒŒ์ผ ์ฒ˜๋ฆฌ
94
+ if image is not None:
95
+ image_file = handle_file(image)
96
+ else:
97
+ image_file = None
98
+
99
+ # ๋น„๋””์˜ค ์ƒ์„ฑ ์š”์ฒญ
100
  result = api_client.predict(
101
+ image=image_file,
102
+ prompt=prompt,
103
+ steps=steps,
104
+ cfg_scale=cfg_scale,
105
+ eta=eta,
106
+ fs=fs,
107
  seed=seed,
108
+ video_length=video_length,
109
+ api_name="/infer"
 
 
 
 
110
  )
111
  logging.info("API response received: %s", result)
112
+
113
+ if isinstance(result, dict) and 'video' in result:
114
+ saved_video_path = save_to_gallery(result['video'], prompt)
115
+ return saved_video_path
 
 
116
  else:
117
+ raise ValueError("์˜ˆ์ƒ์น˜ ๋ชปํ•œ API ์‘๋‹ต ํ˜•์‹")
118
  except Exception as e:
119
+ logging.error("API ์š”์ฒญ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: %s", str(e))
120
+ return "์˜ค๋ฅ˜๋กœ ์ธํ•ด ๋น„๋””์˜ค ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."
121
 
122
  css = """
123
  footer {
 
155
  ["A fantasy map of a fictional world, with detailed terrain and cities.", "q19.webp"]
156
  ]
157
 
158
+
159
+ def use_prompt_and_image(prompt, image):
160
+ return prompt, image
161
+
162
+ def show_video(evt: gr.SelectData):
163
+ return evt.value["video"] if isinstance(evt.value, dict) and "video" in evt.value else None
164
 
165
  with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
166
  with gr.Tab("Generate"):
167
  with gr.Row():
168
+ input_image = gr.Image(label="Upload an image", type="filepath")
169
+ input_text = gr.Textbox(label="Enter your prompt for video generation")
170
+ output_video = gr.Video(label="Generated Video")
171
 
172
  with gr.Row():
173
+ steps = gr.Slider(minimum=1, maximum=100, step=1, label="Steps", value=30)
174
+ cfg_scale = gr.Slider(minimum=1, maximum=15, step=0.1, label="CFG Scale", value=3.5)
175
+ eta = gr.Slider(minimum=0, maximum=1, step=0.1, label="ETA", value=1)
176
+ fs = gr.Slider(minimum=1, maximum=30, step=1, label="FPS", value=8)
177
  seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="Seed", value=123)
178
+ video_length = gr.Slider(minimum=1, maximum=10, step=1, label="Video Length (seconds)", value=2)
 
 
 
 
 
 
 
 
179
 
180
  with gr.Row():
181
  for prompt, image_file in examples:
182
  with gr.Column():
183
  gr.Image(image_file, label=prompt[:50] + "...")
184
+ gr.Button("Use this example").click(
185
+ fn=use_prompt_and_image,
186
  inputs=[],
187
+ outputs=[input_text, input_image],
188
  api_name=False
189
  ).then(
190
+ lambda p=prompt, i=image_file: (p, i),
191
  inputs=[],
192
+ outputs=[input_text, input_image]
193
  )
194
+
195
  with gr.Tab("Gallery"):
196
  gallery = gr.Gallery(
197
+ label="Generated Videos",
198
  show_label=False,
199
  elem_id="gallery",
200
  columns=[5],
201
  rows=[3],
202
  object_fit="contain",
203
+ height="auto",
204
+ preview=True
205
  )
206
+ selected_video = gr.Video(label="Selected Video")
207
  refresh_btn = gr.Button("Refresh Gallery")
208
 
 
 
 
209
  refresh_btn.click(fn=update_gallery, inputs=None, outputs=gallery)
210
  demo.load(fn=update_gallery, inputs=None, outputs=gallery)
211
+ gallery.select(show_video, None, selected_video)
212
 
213
  input_text.submit(
214
  fn=respond,
215
+ inputs=[input_image, input_text, steps, cfg_scale, eta, fs, seed, video_length],
216
+ outputs=output_video
217
  ).then(
218
  fn=update_gallery,
219
  inputs=None,