ali-kanbar committed
Commit 2333322 · verified · 1 Parent(s): 4a9fb82

Upload 2 files

Files changed (2):
  1. app.py +239 -0
  2. requirements.txt +16 -0
app.py ADDED
@@ -0,0 +1,239 @@
import gradio as gr
import asyncio
import os
import traceback
import numpy as np
import re
from functools import partial
import torch
import imageio
import cv2
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
import edge_tts
from transformers import AutoTokenizer, pipeline
from moviepy.editor import VideoFileClip, AudioFileClip
from func_timeout import func_timeout, FunctionTimedOut

# Initialize models once at import time, caching weights under model_cache/
def initialize_components():
    # step is included here because generate_video() reads it as a global
    global tokenizer, text_pipe, sentiment_analyzer, pipe, step

    # Text generation components. pipeline() does not accept cache_dir
    # directly; it must be forwarded to from_pretrained via model_kwargs.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", cache_dir="model_cache")
    text_pipe = pipeline(
        "text-generation",
        model="Qwen/Qwen2.5-1.5B-Instruct",
        tokenizer=tokenizer,
        device_map="auto",
        model_kwargs={"cache_dir": "model_cache"}
    )

    # Sentiment analysis (default SST-2 checkpoint)
    sentiment_analyzer = pipeline("sentiment-analysis", model_kwargs={"cache_dir": "model_cache"})

    # Video generation setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    step = 8
    repo = "ByteDance/AnimateDiff-Lightning"
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
    base = "emilianJR/epiCRealism"

    # Load motion adapter with caching
    adapter = MotionAdapter().to(device, dtype)
    model_path = hf_hub_download(repo, ckpt, cache_dir="model_cache")
    adapter.load_state_dict(load_file(model_path, device=device))

    # Initialize pipeline
    pipe = AnimateDiffPipeline.from_pretrained(
        base,
        motion_adapter=adapter,
        torch_dtype=dtype,
        cache_dir="model_cache"
    ).to(device)

    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        beta_schedule="linear"
    )

initialize_components()

# Cleanup function for resource management
def cleanup():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    for f in ["generated_video.mp4", "final_video_with_audio.mp4", "output.mp3"]:
        if os.path.exists(f):
            try:
                os.remove(f)
            except OSError:
                pass

# Scene-by-scene video generation; each diffusion call is wrapped in a hard timeout
def generate_video(summary):
    def crossfade_transition(frames1, frames2, transition_length=10):
        # Blend the tail of the previous scene into the head of the next one
        blended_frames = []
        transition_length = min(transition_length, len(frames1), len(frames2))
        frames1_np = [np.array(frame) for frame in frames1[-transition_length:]]
        frames2_np = [np.array(frame) for frame in frames2[:transition_length]]
        for i in range(transition_length):
            alpha = i / transition_length
            beta = 1.0 - alpha
            blended = cv2.addWeighted(frames1_np[i], beta, frames2_np[i], alpha, 0)
            blended_frames.append(Image.fromarray(blended))
        return blended_frames

    # Split the story into sentences; each sentence becomes one scene prompt
    sentences = []
    current_sentence = ""
    for char in summary:
        current_sentence += char
        if char in {'.', '!', '?'}:
            sentences.append(current_sentence.strip())
            current_sentence = ""
    sentences = [s.strip() for s in sentences if s.strip()]

    output_dir = "generated_frames"
    video_path = "generated_video.mp4"
    os.makedirs(output_dir, exist_ok=True)

    all_frames = []
    previous_frames = None
    transition_frames = 10
    batch_size = 1

    for i in range(0, len(sentences), batch_size):
        batch_prompts = sentences[i : i + batch_size]
        for idx, prompt in enumerate(batch_prompts):
            try:
                output = func_timeout(
                    300,  # 5 minute timeout per scene
                    pipe,
                    args=(prompt,),
                    kwargs={
                        'guidance_scale': 1.0,
                        'num_inference_steps': step,
                        'width': 128,  # reduced resolution to keep generation fast
                        'height': 128
                    }
                )
                frames = output.frames[0]

                if previous_frames is not None:
                    transition = crossfade_transition(previous_frames, frames, transition_frames)
                    all_frames.extend(transition)

                all_frames.extend(frames)
                previous_frames = frames

            except FunctionTimedOut:
                print(f"Timeout generating scene {i + idx + 1}")
                return None
            except Exception as e:
                print(f"Error generating scene: {str(e)}")
                continue

    if not all_frames:  # every scene failed; nothing to encode
        return None
    imageio.mimsave(video_path, all_frames, fps=6)  # reduced FPS
    return video_path

# Main processing function with input validation and error handling
def create_story_video(prompt, progress=gr.Progress()):
    cleanup()  # clear artifacts from previous runs

    if not prompt or len(prompt.strip()) < 5:
        return "Prompt too short (min 5 characters)", None, None
    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)", None, None

    try:
        progress(0.0, desc="Starting story generation...")
        story = generate_story(prompt)
        progress(0.25, desc="Story generated")

        progress(0.30, desc="Starting video generation...")
        video_path = generate_video(story)
        if not video_path:
            return story, None, "Video generation failed"
        progress(0.60, desc="Video rendered")

        progress(0.65, desc="Creating audio summary...")
        audio_summary = summary_of_summary(story, video_path)

        progress(0.75, desc="Generating voiceover...")
        try:
            # edge-tts is async; drive it from a dedicated event loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                audio_file = loop.run_until_complete(
                    generate_audio_with_sentiment(audio_summary, sentiment_analyzer)
                )
            finally:
                loop.close()
        except Exception as e:
            return story, None, f"Audio error: {str(e)}"

        progress(0.90, desc="Finalizing video...")
        output_path = 'final_video_with_audio.mp4'
        combine_video_with_audio(video_path, audio_file, output_path)

        return story, output_path, audio_summary

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(traceback.format_exc())
        return error_msg, None, None

# The remaining helpers (generate_story, summary_of_summary,
# generate_audio_with_sentiment, combine_video_with_audio) are referenced
# above but not included in this commit; stand-in sketches follow below.

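# NOTE: the four sketches below are assumptions, not the author's code. The
# commit elides these helpers, so minimal versions are provided here purely
# so the file runs end to end. The names and signatures match the call sites
# above, but every implementation detail (prompt wording, summary length,
# voice selection) is a guess.

def generate_story(prompt):
    # Assumed: expand the idea into a short story with the Qwen pipeline
    result = text_pipe(
        f"Write a short story (5 to 8 sentences) about: {prompt}\n\nStory:",
        max_new_tokens=300,
        return_full_text=False,
    )
    return result[0]["generated_text"].strip()

def summary_of_summary(story, video_path):
    # Assumed: condense the story for the voiceover. video_path is accepted
    # only to match the call site; a fuller version might size the summary
    # to the clip's duration.
    sentences = re.split(r"(?<=[.!?])\s+", story.strip())
    return " ".join(sentences[:3])

async def generate_audio_with_sentiment(text, analyzer):
    # Assumed: pick an edge-tts voice from the text's sentiment, then
    # synthesize the narration to output.mp3
    label = analyzer(text[:512])[0]["label"]
    voice = "en-US-JennyNeural" if label == "POSITIVE" else "en-US-GuyNeural"
    await edge_tts.Communicate(text, voice).save("output.mp3")
    return "output.mp3"

def combine_video_with_audio(video_path, audio_path, output_path):
    # Assumed: mux the narration onto the silent clip with moviepy
    video = VideoFileClip(video_path)
    audio = AudioFileClip(audio_path)
    if audio.duration > video.duration:
        audio = audio.set_duration(video.duration)
    video.set_audio(audio).write_videofile(output_path, codec="libx264", audio_codec="aac")
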
# Gradio interface setup with resource management
EXAMPLE_PROMPTS = [
    "A nurse discovers an unusual pattern in patient symptoms.",
    "A family finds a time capsule during home renovation.",
    "A restaurant owner innovates to save their business.",
    "Wildlife tracking reveals climate changes.",
    "Community rebuilds after natural disaster."
]

with gr.Blocks(title="AI Story Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 AI Story Video Generator")
    gr.Markdown("Enter a short story idea (5-500 characters)")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Story Idea",
            placeholder="Example: A detective finds a hidden room...",
            max_lines=2
        )

    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=prompt_input,
        label="Example Prompts"
    )

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    with gr.Tabs():
        with gr.Tab("Results"):
            video_output = gr.Video(label="Generated Video", interactive=False)
            story_output = gr.Textbox(label="Full Story", lines=10)
            audio_summary = gr.Textbox(label="Audio Summary", lines=3)

    generate_btn.click(
        fn=create_story_video,
        inputs=prompt_input,
        outputs=[story_output, video_output, audio_summary]
    )

    clear_btn.click(
        fn=lambda: [None, None, None],
        outputs=[story_output, video_output, audio_summary]
    )

    demo.load(fn=cleanup)
    if hasattr(demo, "unload"):  # Blocks.unload is not present in every Gradio 4.x release
        demo.unload(fn=cleanup)

if __name__ == "__main__":
    demo.launch(server_port=7860, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,16 @@
gradio==4.25.0
edge-tts==6.1.3
torch==2.3.0
torchvision==0.18.0
diffusers==0.28.2
transformers==4.41.0
imageio==2.34.0
imageio-ffmpeg==0.4.9  # needed by imageio to encode mp4; absent from the original list
opencv-python==4.9.0.80
moviepy==1.0.3
safetensors==0.4.2
huggingface-hub==0.23.0
numpy==1.26.4
Pillow==10.3.0
accelerate==0.30.0
func-timeout==4.3.5  # imported by app.py but absent from the original list
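To try the app outside of a Hugging Face Space (an assumption; the two files follow the standard Spaces layout): install the dependencies with pip install -r requirements.txt, then start it with python app.py. Per the demo.launch call above, the interface is served on port 7860.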