kevalfst committed
Commit 76f81b8 · verified · 1 Parent(s): 7fae245

Update app.py

Files changed (1): app.py (+124, -23)
app.py CHANGED
@@ -1,14 +1,23 @@
 import gradio as gr
 import torch
 import random
+import hashlib
 from diffusers import DiffusionPipeline
 from transformers import pipeline
+from diffusers.utils import export_to_video
+
+# Optional: xformers optimization
+try:
+    import xformers
+    has_xformers = True
+except ImportError:
+    has_xformers = False
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 MAX_SEED = 2**32 - 1
 
-# --- Model lists ordered by size (light to heavy) ---
+# Model lists ordered by size
 image_models = {
     "Stable Diffusion 1.5 (light)": "runwayml/stable-diffusion-v1-5",
     "Stable Diffusion 2.1": "stabilityai/stable-diffusion-2-1",
@@ -35,46 +44,136 @@ text_models = {
     "LLaMA 2 7B (heavy)": "meta-llama/Llama-2-7b-hf"
 }
 
-# Cache
+video_models = {
+    "CogVideoX-2B": "THUDM/CogVideoX-2b",
+    "CogVideoX-5B": "THUDM/CogVideoX-5b",
+    "AnimateDiff-Lightning": "ByteDance/AnimateDiff-Lightning",
+    "ModelScope T2V": "damo-vilab/text-to-video-ms-1.7b",
+    "VideoCrafter2": "VideoCrafter/VideoCrafter2",
+    "Open-Sora-Plan-v1.2.0": "LanguageBind/Open-Sora-Plan-v1.2.0",
+    "LTX-Video": "Lightricks/LTX-Video",
+    "HunyuanVideo": "tencent/HunyuanVideo",
+    "Latte-1": "maxin-cn/Latte-1",
+    "LaVie": "Vchitect/LaVie"
+}
+
+# Caches
 image_pipes = {}
 text_pipes = {}
+video_pipes = {}
+image_cache = {}
+text_cache = {}
+video_cache = {}
+
+def hash_inputs(*args):
+    combined = "|".join(map(str, args))
+    return hashlib.sha256(combined.encode()).hexdigest()
 
 def generate_image(prompt, model_name, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.manual_seed(seed)
 
-    progress(0, desc="Loading model...")
+    key = hash_inputs(prompt, model_name, seed)
+    if key in image_cache:
+        progress(100, desc="Using cached image.")
+        return image_cache[key], seed
+
+    progress(10, desc="Loading model...")
     if model_name not in image_pipes:
-        image_pipes[model_name] = DiffusionPipeline.from_pretrained(
+        pipe = DiffusionPipeline.from_pretrained(
             image_models[model_name],
-            torch_dtype=torch_dtype
-        ).to(device)
+            torch_dtype=torch_dtype,
+            low_cpu_mem_usage=True
+        )
+
+        if torch.__version__.startswith("2"):
+            pipe = torch.compile(pipe)
+        if has_xformers and device == "cuda":
+            try:
+                pipe.enable_xformers_memory_efficient_attention()
+            except Exception:
+                pass
+
+        pipe.to(device)
+        image_pipes[model_name] = pipe
+
     pipe = image_pipes[model_name]
 
-    progress(25, desc="Running inference (step 1/3)...")
-    result = pipe(prompt=prompt, generator=generator, num_inference_steps=30, width=512, height=512)
+    progress(40, desc="Generating image...")
+    result = pipe(prompt=prompt, generator=torch.manual_seed(seed), num_inference_steps=15, width=512, height=512)
+    image = result.images[0]
+    image_cache[key] = image
 
     progress(100, desc="Done.")
-    return result.images[0], seed
+    return image, seed
 
 def generate_text(prompt, model_name, progress=gr.Progress(track_tqdm=True)):
-    progress(0, desc="Loading model...")
+    key = hash_inputs(prompt, model_name)
+    if key in text_cache:
+        progress(100, desc="Using cached text.")
+        return text_cache[key]
+
+    progress(10, desc="Loading model...")
     if model_name not in text_pipes:
-        text_pipes[model_name] = pipeline("text-generation", model=text_models[model_name], device=0 if device == "cuda" else -1)
+        text_pipes[model_name] = pipeline(
+            "text-generation",
+            model=text_models[model_name],
+            device=0 if device == "cuda" else -1
+        )
     pipe = text_pipes[model_name]
 
-    progress(50, desc="Generating text...")
+    progress(40, desc="Generating text...")
    result = pipe(prompt, max_length=100, do_sample=True)[0]['generated_text']
+    text_cache[key] = result
+
     progress(100, desc="Done.")
     return result
 
+def generate_video(prompt, model_name, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+
+    key = hash_inputs(prompt, model_name, seed)
+    if key in video_cache:
+        progress(100, desc="Using cached video.")
+        return video_cache[key], seed
+
+    progress(10, desc="Loading model...")
+    if model_name not in video_pipes:
+        pipe = DiffusionPipeline.from_pretrained(
+            video_models[model_name],
+            torch_dtype=torch_dtype,
+            variant="fp16"
+        )
+
+        if torch.__version__.startswith("2"):
+            pipe = torch.compile(pipe)
+        if has_xformers and device == "cuda":
+            try:
+                pipe.enable_xformers_memory_efficient_attention()
+            except Exception:
+                pass
+
+        pipe.to(device)
+        video_pipes[model_name] = pipe
+
+    pipe = video_pipes[model_name]
+
+    progress(40, desc="Generating video...")
+    result = pipe(prompt=prompt, generator=torch.manual_seed(seed), num_inference_steps=15)
+    video_frames = result.frames[0]
+    video_path = export_to_video(video_frames)
+    video_cache[key] = video_path
+
+    progress(100, desc="Done.")
+    return video_path, seed
+
 # Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("# 🧠 Visionary AI")
+    gr.Markdown("# Fast Multi-Model AI Playground with Caching")
 
     with gr.Tabs():
-        # 🖼️ Image Gen Tab
+        # Image Generation
         with gr.Tab("🖼️ Image Generation"):
             img_prompt = gr.Textbox(label="Prompt")
             img_model = gr.Dropdown(choices=list(image_models.keys()), value="Stable Diffusion 1.5 (light)", label="Image Model")
@@ -84,20 +183,22 @@ with gr.Blocks() as demo:
             img_out = gr.Image()
             img_btn.click(fn=generate_image, inputs=[img_prompt, img_model, img_seed, img_rand], outputs=[img_out, img_seed])
 
-        # 📝 Text Gen Tab
+        # Text Generation
         with gr.Tab("📝 Text Generation"):
             txt_prompt = gr.Textbox(label="Prompt")
             txt_model = gr.Dropdown(choices=list(text_models.keys()), value="GPT-2 (light)", label="Text Model")
             txt_btn = gr.Button("Generate Text")
             txt_out = gr.Textbox(label="Output Text")
-            txt_btn.click(fn=generate_text, inputs=[txt_prompt, txt_model], outputs=txt_out)
+            txt_btn.click(fn=generate_text, inputs=[txt_prompt, txt_model], outputs=[txt_out])
 
-        # 🎥 Video Gen Tab (placeholder)
-        with gr.Tab("🎥 Video Generation (Coming Soon)"):
-            gr.Markdown("⚠️ Video generation is placeholder only. Models require special setup.")
+        # Video Generation
+        with gr.Tab("🎥 Video Generation"):
             vid_prompt = gr.Textbox(label="Prompt")
-            vid_btn = gr.Button("Pretend to Generate")
-            vid_out = gr.Textbox(label="Result")
-            vid_btn.click(lambda x: f"Fake video output for: {x}", inputs=[vid_prompt], outputs=[vid_out])
+            vid_model = gr.Dropdown(choices=list(video_models.keys()), value="CogVideoX-2B", label="Video Model")
+            vid_seed = gr.Slider(0, MAX_SEED, value=42, label="Seed")
+            vid_rand = gr.Checkbox(label="Randomize seed", value=True)
+            vid_btn = gr.Button("Generate Video")
+            vid_out = gr.Video()
+            vid_btn.click(fn=generate_video, inputs=[vid_prompt, vid_model, vid_seed, vid_rand], outputs=[vid_out, vid_seed])
 
 demo.launch(show_error=True)
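
The core of this update is the input-keyed cache: each generate_* function hashes its inputs with hash_inputs and returns a stored result when the same request is repeated. Below is a minimal, self-contained sketch of that pattern; it reuses the hash_inputs helper from the diff, while fake_generate and cached_generate are hypothetical stand-ins for the real pipeline calls.

import hashlib

def hash_inputs(*args):
    # Same helper as in app.py: join the inputs with "|" and hash them.
    combined = "|".join(map(str, args))
    return hashlib.sha256(combined.encode()).hexdigest()

image_cache = {}

def fake_generate(prompt, model_name, seed):
    # Hypothetical stand-in for the real diffusion pipeline call.
    return f"image for {prompt!r} from {model_name} (seed={seed})"

def cached_generate(prompt, model_name, seed):
    # Identical (prompt, model_name, seed) triples hash to the same key,
    # so a repeated request is served from the cache without re-running inference.
    key = hash_inputs(prompt, model_name, seed)
    if key not in image_cache:
        image_cache[key] = fake_generate(prompt, model_name, seed)
    return image_cache[key]

cached_generate("a red fox", "Stable Diffusion 1.5 (light)", 42)  # computes and stores
cached_generate("a red fox", "Stable Diffusion 1.5 (light)", 42)  # cache hit

Note that with the "Randomize seed" checkbox enabled, each click draws a fresh seed, so cache hits in the app only occur when the seed is fixed (or when the same random seed happens to recur).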