Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import json | |
| import time | |
| import random | |
| import subprocess | |
| from pathlib import Path | |
| from typing import List, Any | |
| import google.generativeai as genai | |
| from tavily import TavilyClient | |
| from runwayml import RunwayML, TaskFailedError | |
| from PIL import Image, ImageDraw, ImageFont | |
| # ============================================================= | |
| # AI VIDEO STUDIO (Gen-4 Turbo Image→Video) – Robust Version | |
| # ============================================================= | |
| # Improvements in this revision: | |
| # - Normalizes narration if model returns list (was causing list.split() AttributeError). | |
| # - Defensive checks & type coercion for scene_prompts. | |
| # - Safer JSON extraction (optionally attempts a JSON substring if extra text present). | |
| # - Fixed accidental newline handling for 'facts'. | |
| # - Added explicit JSON enforcement hint to Gemini. | |
| # - Added helper to truncate overly long narration. | |
| # - Added more granular progress steps & logging. | |
| # - Added retry for Gemini (transient failures) and for Runway polling. | |
| # - Added validate_scene_prompts() to guarantee list[str] length == SCENE_COUNT. | |
| # ============================================================= | |
| # --- 1. CONFIGURE API KEYS --- | |
# Configure the three external clients at import time. A missing secret is
# surfaced as a ValueError naming the variable; the original KeyError is kept
# as the explicit cause so the traceback shows exactly which lookup failed.
try:
    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
    tavily_client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
    RUNWAY_API_KEY = os.environ["RUNWAY_API_KEY"]
    runway_client = RunwayML(api_key=RUNWAY_API_KEY)
except KeyError as e:
    raise ValueError(f"API Key Error: Please set the {e} secret in your environment.") from e
| # --- 2. CONSTANTS / SETTINGS --- | |
# Runway generation settings.
GEN4_MODEL = "gen4_turbo"  # adjust to "gen4" for non-turbo
SCENE_COUNT = 4  # number of clips generated and concatenated per video
SCENE_DURATION_SECONDS = 5  # 5 or 10 supported
VIDEO_RATIO = "1280:720"  # 16:9

# Narration pacing: used to size the silent placeholder audio track and to
# cap how long a generated script may be.
WORDS_PER_SEC = 2.5
MAX_NARRATION_WORDS = 520  # safeguard length

# Runway task polling.
MAX_POLL_SECONDS = 3 * 60  # per scene
POLL_INTERVAL = 5  # seconds between status checks

# Number of attempts made to get parseable JSON out of Gemini.
GEMINI_MAX_RETRIES = 2
| # --- 3. UTILITIES --- | |
| def _log(msg: str): | |
| print(f"[AI-STUDIO] {msg}") | |
def create_placeholder_image(text: str, path: Path, size=(1280, 720)) -> Path:
    """Create a simple placeholder keyframe if user supplies none.

    Renders *text* in light grey, word-wrapped and centred, on a near-black
    canvas and saves it to *path*.

    Args:
        text: Caption to draw (the video topic in this app).
        path: Destination file; the image format follows the file suffix.
        size: ``(width, height)`` of the canvas in pixels.

    Returns:
        *path*, for convenient chaining.
    """
    import textwrap

    img = Image.new("RGB", size, (10, 10, 10))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 60)
    except (OSError, ImportError):
        # Font file (or FreeType support) unavailable: use Pillow's built-in font.
        font = ImageFont.load_default()

    # Collapse whitespace first so wrapping works on single-space-separated
    # words; lines hold at most 28 characters and long words are never split.
    lines = textwrap.wrap(
        " ".join(text.split()),
        width=28,
        break_long_words=False,
        break_on_hyphens=False,
    )

    def measure(line: str):
        """Return (width, height) of *line* when drawn with *font*."""
        # ImageDraw.textsize() was removed in Pillow 10; textbbox() is its
        # replacement. Keep the old call only as a fallback for old Pillow.
        if hasattr(draw, "textbbox"):
            left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
            return right - left, bottom - top
        return draw.textsize(line, font=font)

    gap = 10  # vertical spacing between wrapped lines, in pixels
    measured = [(line, *measure(line)) for line in lines]
    # Centre the whole text block using the measured heights instead of a
    # hard-coded per-line estimate.
    total_height = sum(h for _, _, h in measured) + gap * max(len(measured) - 1, 0)
    y = (size[1] - total_height) // 2
    for line, width, height in measured:
        draw.text(((size[0] - width) // 2, y), line, fill=(240, 240, 240), font=font)
        y += height + gap

    img.save(path)
    return path
def generate_mock_voiceover(narration: str, out_path: Path):
    """Write a silent MP3 whose length matches the narration's reading time.

    The duration is the narration's word count divided by ``WORDS_PER_SEC``.
    ffmpeg's ``anullsrc`` source generates the silence; a non-zero ffmpeg exit
    status raises ``subprocess.CalledProcessError``.

    Returns:
        The duration of the generated track in seconds.
    """
    word_total = len(narration.split())
    seconds = word_total / WORDS_PER_SEC
    command = [
        'ffmpeg', '-f', 'lavfi', '-i', 'anullsrc=r=44100:cl=mono',
        '-t', str(seconds), '-q:a', '9', '-acodec', 'libmp3lame', str(out_path), '-y',
    ]
    subprocess.run(command, check=True)
    return seconds
def poll_runway_task(task_obj, max_seconds=MAX_POLL_SECONDS, interval=POLL_INTERVAL):
    """Block until a Runway task succeeds, fails, or exceeds the time budget.

    Args:
        task_obj: Task handle returned by a Runway ``create`` call.
        max_seconds: Maximum time to wait before giving up.
        interval: Seconds to sleep between status checks.

    Returns:
        The up-to-date task object once its status is ``'SUCCEEDED'``.

    Raises:
        TaskFailedError: The task reported status ``'FAILED'``.
        TimeoutError: The task did not finish within *max_seconds*.
    """
    # Monotonic clock: elapsed time must not jump if the wall clock is adjusted.
    started = time.monotonic()
    while True:
        # Objects that can refresh themselves are updated in place; otherwise
        # the task is re-fetched by id through the module-level client.
        # NOTE(review): the Runway SDK's create() response is not documented
        # to expose refresh(); tasks.retrieve() is the documented polling call.
        refresh = getattr(task_obj, "refresh", None)
        if callable(refresh):
            refresh()
        else:
            task_obj = runway_client.tasks.retrieve(task_obj.id)

        status = task_obj.status
        if status == 'SUCCEEDED':
            return task_obj
        if status == 'FAILED':
            raise TaskFailedError(task_details=task_obj)
        if time.monotonic() - started > max_seconds:
            raise TimeoutError(f"Runway task timed out after {max_seconds}s (status={status})")
        time.sleep(interval)
def extract_json_block(text: str) -> str:
    """Return the span from the first ``{`` to the last ``}`` in *text*.

    Used to strip chatter a model may put around a JSON object. When no such
    well-ordered brace pair exists, *text* is returned unchanged.
    """
    start, end = text.find('{'), text.rfind('}')
    if start == -1 or end == -1 or end <= start:
        return text
    return text[start:end + 1]
def coerce_narration(narr: Any) -> str:
    """Normalise a model-supplied narration value into a single string.

    Lists are joined with spaces, other non-string values are stringified, and
    the result is truncated to ``MAX_NARRATION_WORDS`` words.

    A missing narration (``None``, e.g. JSON ``null``) yields an empty string,
    and ``None`` entries inside a list are skipped, so the literal word
    "None" never ends up in the script.
    """
    if narr is None:
        return ""
    if isinstance(narr, list):
        narr = ' '.join(str(part) for part in narr if part is not None)
    if not isinstance(narr, str):
        narr = str(narr)
    words = narr.split()
    if len(words) > MAX_NARRATION_WORDS:
        narr = ' '.join(words[:MAX_NARRATION_WORDS])
    return narr.strip()
def validate_scene_prompts(sp: Any) -> List[str]:
    """Coerce a model-supplied value into exactly ``SCENE_COUNT`` prompt strings.

    A scalar is wrapped in a list, one level of nested lists is flattened, and
    every item is stringified and stripped. Too many prompts are trimmed; too
    few are padded by repeating the last one.

    Raises:
        ValueError: No usable prompt was supplied (``None``, an empty list, or
            a list containing only ``None``), so there is nothing to pad with.
    """
    if sp is None:
        sp = []
    elif not isinstance(sp, list):
        sp = [sp]

    flat: List[str] = []
    for item in sp:
        if isinstance(item, list):
            flat.extend(str(x) for x in item if x is not None)
        elif item is not None:
            flat.append(str(item))

    # Padding repeats the last prompt, which requires at least one to exist;
    # fail with a clear message instead of an IndexError on flat[-1].
    if not flat:
        raise ValueError("No scene prompts were provided.")

    # Trim or pad
    if len(flat) < SCENE_COUNT:
        flat.extend([flat[-1]] * (SCENE_COUNT - len(flat)))
    return [s.strip() for s in flat[:SCENE_COUNT]]
def call_gemini_script(topic: str, facts: str) -> tuple[str, List[str]]:
    """Ask Gemini for a narration script and per-scene visual prompts.

    The model is instructed to return strict JSON. The reply is stripped of
    markdown fences, reduced to its outermost ``{...}`` span, parsed, and then
    normalised with ``coerce_narration`` and ``validate_scene_prompts``.

    Args:
        topic: The video topic entered by the user.
        facts: Research text to ground the script (may be noisy).

    Returns:
        ``(narration, scene_prompts)``.

    Raises:
        ValueError: Every one of the ``GEMINI_MAX_RETRIES`` attempts failed;
            the last underlying exception is attached as the cause.
    """
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
    script_prompt = f"""
You are a creative director for viral short-form educational videos.
Topic: {topic}
Research (may contain noise):
{facts}
STRICT JSON OUTPUT ONLY. Do not add commentary or markdown fences.
Schema: {{"narration_script": string, "scene_prompts": list[{SCENE_COUNT}]}}
narration_script rules: energetic, cohesive, <= {MAX_NARRATION_WORDS} words total, no scene numbers.
scene_prompts: exactly {SCENE_COUNT} cinematic visual descriptions (1-2 sentences each) including style, camera, lighting.
Return JSON ONLY.
"""
    last_error = None
    for attempt in range(GEMINI_MAX_RETRIES):
        try:
            response = gemini_model.generate_content(script_prompt)
            raw = response.text.strip()
            raw = raw.replace('```json', '').replace('```', '').strip()
            raw = extract_json_block(raw)
            data = json.loads(raw)
            narration = coerce_narration(data.get('narration_script', ''))
            scene_prompts = validate_scene_prompts(data.get('scene_prompts', []))
            return narration, scene_prompts
        except Exception as e:
            # Broad on purpose: API errors, blocked responses, malformed JSON
            # and wrong shapes are all treated as retryable.
            last_error = e
            _log(f"Gemini attempt {attempt + 1}/{GEMINI_MAX_RETRIES} failed: {e}")
            # Back off only when another attempt will follow.
            if attempt + 1 < GEMINI_MAX_RETRIES:
                time.sleep(1 + attempt)
    raise ValueError(
        f"Gemini JSON parse failed after {GEMINI_MAX_RETRIES} attempts: {last_error}"
    ) from last_error
| # --- 4. CORE PIPELINE --- | |
def _image_to_data_uri(image_path: Path) -> str:
    """Encode a local image file as a base64 ``data:`` URI."""
    import base64
    import mimetypes

    mime_type = mimetypes.guess_type(image_path.name)[0] or "image/png"
    encoded = base64.b64encode(image_path.read_bytes()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def _download_file(url: str, dest: Path, timeout: float = 60.0) -> None:
    """Stream the resource at *url* into the file *dest*."""
    import shutil
    import urllib.request

    with urllib.request.urlopen(url, timeout=timeout) as response, open(dest, "wb") as out:
        shutil.copyfileobj(response, out)


def generate_video_from_topic(topic_prompt, keyframe_image, progress=gr.Progress(track_tqdm=True)):
    """Run the full topic -> video pipeline and return the final MP4 path.

    Steps: research the topic with Tavily (best effort), have Gemini write a
    narration and scene prompts, create a silent placeholder audio track,
    prepare a keyframe image, generate one Runway clip per scene prompt,
    concatenate the clips with ffmpeg and mux in the audio.

    Args:
        topic_prompt: Topic text entered by the user.
        keyframe_image: Path of an uploaded keyframe image, or ``None`` to
            generate a text placeholder.
        progress: Gradio progress reporter.

    Returns:
        Path (as ``str``) of the final video inside the per-job directory.
        The job directory is kept on disk for debugging.

    Raises:
        gr.Error: The topic is empty or any pipeline step failed.
    """
    topic_prompt = (topic_prompt or "").strip()
    if not topic_prompt:
        raise gr.Error("Please enter a video topic.")

    job_id = f"{int(time.time())}_{random.randint(1000, 9999)}"
    _log(f"Starting job {job_id} :: topic='{topic_prompt}'")
    workdir = Path(f"job_{job_id}")
    workdir.mkdir(exist_ok=True)
    try:
        # STEP 1: Research (best effort; a failure falls back to a stub)
        progress(0.05, desc="🔍 Researching topic ...")
        facts = "No research data available."
        try:
            research_results = tavily_client.search(
                query=f"Key facts and interesting points about {topic_prompt}",
                search_depth="basic"
            )
            if research_results and 'results' in research_results:
                facts = "\n".join(
                    res.get('content') or '' for res in research_results['results']
                )
        except Exception as e:
            _log(f"Tavily failed: {e}")

        # STEP 2: Script
        progress(0.15, desc="✍️ Writing script ...")
        narration, scene_prompts = call_gemini_script(topic_prompt, facts)
        _log(f"Narration words: {len(narration.split())}; scenes: {len(scene_prompts)}")

        # STEP 3: Mock VO
        progress(0.25, desc="🎙️ Generating mock VO ...")
        audio_path = workdir / f"narration_{job_id}.mp3"
        generate_mock_voiceover(narration, audio_path)

        # STEP 4: Keyframe image (required for Gen-4 image_to_video)
        progress(0.30, desc="🖼️ Preparing keyframe image ...")
        if keyframe_image is not None:
            keyframe_path = Path(keyframe_image)
        else:
            keyframe_path = workdir / "auto_keyframe.png"
            create_placeholder_image(topic_prompt, keyframe_path)
        # The Runway API takes the image as an HTTPS URL or a data URI, not a
        # local filesystem path, so the file is inlined once for all scenes.
        keyframe_uri = _image_to_data_uri(keyframe_path)

        # STEP 5: Generate scenes
        clip_paths: List[Path] = []
        for idx, scene_prompt in enumerate(scene_prompts, start=1):
            base_progress = 0.30 + (idx * 0.12)
            progress(min(base_progress, 0.85), desc=f"🎬 Scene {idx}/{len(scene_prompts)} ...")
            _log(f"Submitting scene {idx}: {scene_prompt[:100]} ...")
            try:
                task = runway_client.image_to_video.create(
                    model=GEN4_MODEL,
                    prompt_image=keyframe_uri,  # required
                    prompt_text=scene_prompt,
                    duration=SCENE_DURATION_SECONDS,
                    ratio=VIDEO_RATIO,
                )
                task = poll_runway_task(task)
                video_url = task.output[0]
            except TaskFailedError as e:
                raise gr.Error(f"Runway failed scene {idx}: {getattr(e, 'task_details', 'No details')}")
            clip_path = workdir / f"scene_{idx}.mp4"
            # Plain HTTP download; avoids reaching into the SDK client's
            # private attributes.
            _download_file(video_url, clip_path)
            clip_paths.append(clip_path)
            _log(f"Downloaded scene {idx} -> {clip_path}")

        # STEP 6: Concatenate video
        progress(0.90, desc="✂️ Concatenating scenes ...")
        list_file = workdir / "clips.txt"
        with open(list_file, 'w') as lf:
            for p in clip_paths:
                # ffmpeg's concat demuxer resolves relative entries against
                # the list file's own directory, and the clips sit next to
                # clips.txt, so only the bare file name is written.
                lf.write(f"file '{p.name}'\n")
        concat_path = workdir / f"concat_{job_id}.mp4"
        subprocess.run([
            'ffmpeg', '-f', 'concat', '-safe', '0', '-i', str(list_file), '-c', 'copy', str(concat_path), '-y'
        ], check=True)

        # STEP 7: Mux audio
        final_path = workdir / f"final_{job_id}.mp4"
        progress(0.95, desc="🔊 Merging audio ...")
        subprocess.run([
            'ffmpeg', '-i', str(concat_path), '-i', str(audio_path), '-c:v', 'copy', '-c:a', 'aac', '-shortest', str(final_path), '-y'
        ], check=True)
        progress(1.0, desc="✅ Done")
        _log(f"FINAL VIDEO: {final_path}")
        return str(final_path)
    except gr.Error as e:
        # Already user-facing: log and pass through without re-wrapping.
        _log(f"JOB {job_id} FAILED: {e}")
        raise
    except Exception as e:
        _log(f"JOB {job_id} FAILED: {e}")
        raise gr.Error(f"An error occurred: {e}") from e
| # --- 5. GRADIO UI --- | |
# Gradio UI: topic + optional keyframe in, generated video out. The button is
# wired to generate_video_from_topic.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 My Personal AI Video Studio (Gen-4 Turbo)")
    gr.Markdown("Enter a topic and (optionally) upload a keyframe image. Without an image, a simple placeholder is generated.")
    with gr.Row():
        topic_input = gr.Textbox(label="Video Topic", placeholder="e.g., 'The history of coffee'", scale=3)
        # type="filepath" hands the handler a path string (or None if empty).
        image_input = gr.Image(label="Keyframe Image (optional)", type="filepath")
    with gr.Row():
        generate_button = gr.Button("Generate Video", variant="primary")
    with gr.Row():
        video_output = gr.Video(label="Generated Video")
    generate_button.click(
        fn=generate_video_from_topic,
        inputs=[topic_input, image_input],
        outputs=video_output
    )
    # Explicit "\n" separators: a plain quoted string cannot span physical
    # lines, and each markdown bullet must start on its own line.
    gr.Markdown(
        "---\n"
        "### Tips\n"
        "- Supply a consistent character/style image for more coherent scenes.\n"
        "- Gen-4 requires an input image + (optional) text prompt; pure text alone is not supported in this flow.\n"
        "- For pure text-to-video consider a Gen-3 text model.\n"
        "- Replace placeholder keyframe logic with a real T2I model for higher quality."
    )

if __name__ == "__main__":
    demo.launch()