Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| from transformers import AutoTokenizer | |
| from TTS.api import TTS | |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
| from torchvision.io import write_video | |
| import os | |
| import groq | |
| import logging | |
| from pathlib import Path | |
| import cv2 | |
| from moviepy.editor import VideoFileClip, AudioFileClip, CompositeVideoClip | |
# Set up logging: configure the root logger at INFO level and create the
# module-level logger that the generator methods use to report failures.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EnhancedContentGenerator:
    """Generate short cartoon clips: an LLM-written text, a video and narration.

    The pipeline for both public entry points is the same:

    1. ask the Groq chat API for a script / lyrics,
    2. render one Stable Diffusion image per scene and repeat it to form a
       frame sequence,
    3. synthesise the text to speech and post-process the audio,
    4. encode the frames to an ``.mp4`` file.

    The video and the audio are written as two separate files (they are shown
    in separate Gradio components); they are not muxed together here.

    Relies on these module-level names: ``os``, ``Path``, ``torch``, ``np``,
    ``cv2``, ``groq``, ``TTS``, ``StableDiffusionPipeline``,
    ``DPMSolverMultistepScheduler``, ``write_video``, ``AudioFileClip`` and
    ``logger``.
    """

    # Groq chat model used for script and lyrics generation.
    GROQ_MODEL = "mixtral-8x7b-32768"

    # Preferred voice for the multi-speaker VCTK model.
    # NOTE(review): presumably "p225" exists in this model's speaker list;
    # __init__ falls back to the first listed speaker when it does not.
    DEFAULT_SPEAKER = "p225"

    # Text appended to every image prompt, keyed by style name.
    STYLE_PROMPTS = {
        "cartoon": "in the style of a western cartoon, vibrant colors, simple shapes",
        "anime": "in the style of Studio Ghibli anime, detailed backgrounds",
        "kids": "in the style of a children's book illustration, cute and colorful",
    }

    def __init__(self):
        """Load the models and create the output directories.

        Raises:
            ValueError: if the ``GROQ_API_KEY`` environment variable is unset
                or empty.
        """
        # Check for API key
        self.api_key = os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        # Create output directories if they don't exist
        self.output_dir = Path("generated_content")
        self.audio_dir = self.output_dir / "audio"
        self.video_dir = self.output_dir / "video"
        for directory in (self.output_dir, self.audio_dir, self.video_dir):
            directory.mkdir(parents=True, exist_ok=True)

        # Initialize TTS with a more cartoon-appropriate voice
        self.tts = TTS(model_name="tts_models/en/vctk/vits")
        # The VCTK model is multi-speaker, so synthesis needs an explicit
        # speaker. Single-speaker models expose no speaker list and get None.
        available_speakers = getattr(self.tts, "speakers", None) or []
        if self.DEFAULT_SPEAKER in available_speakers:
            self.tts_speaker = self.DEFAULT_SPEAKER
        elif available_speakers:
            self.tts_speaker = available_speakers[0]
        else:
            self.tts_speaker = None

        # Initialize Stable Diffusion with cartoon-specific model
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "nitrosocke/Ghibli-Diffusion",  # Using anime/cartoon style model
            torch_dtype=torch.float32,
        )
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        self.pipe = self.pipe.to("cpu")
        self.pipe.enable_attention_slicing()

        # Initialize Groq client
        self.groq_client = groq.Groq(api_key=self.api_key)

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _split_scenes(script):
        """Return the non-blank, stripped chunks of ``script`` split on blank lines."""
        return [chunk.strip() for chunk in script.split("\n\n") if chunk.strip()]

    @staticmethod
    def _content_id(text):
        """Return a short hex digest of ``text`` that is stable across runs.

        The built-in ``hash()`` is salted per process and may be negative, so
        it is not suitable for building file names.
        """
        import hashlib

        return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]

    @staticmethod
    def _stack_frames(frames):
        """Stack equally-sized RGB frames into one (T, H, W, 3) uint8 array.

        Raises:
            ValueError: if ``frames`` is empty, the frames differ in shape, or
                the stacked result is not a 4-D array with three channels.
        """
        if not frames:
            raise ValueError("No frames were generated for the video")
        stacked = np.stack(frames)
        if stacked.ndim != 4 or stacked.shape[-1] != 3:
            raise ValueError(
                f"Expected frames shaped (H, W, 3), got stacked shape {stacked.shape}"
            )
        return np.ascontiguousarray(stacked, dtype=np.uint8)

    # ------------------------------------------------------------------
    # Image generation
    # ------------------------------------------------------------------
    def generate_cartoon_frame(self, prompt, style="cartoon"):
        """Generate a single cartoon frame with specified style.

        Args:
            prompt: Text describing the image.
            style: Key into ``STYLE_PROMPTS``; unknown styles fall back to
                ``"cartoon"``.

        Returns:
            The first generated image converted with ``np.array``.
            NOTE(review): presumably an H x W x 3 uint8 RGB array, which is
            what the cv2 calls in ``add_cartoon_effects`` assume; confirm.
        """
        style_suffix = self.STYLE_PROMPTS.get(style, self.STYLE_PROMPTS["cartoon"])
        enhanced_prompt = f"{prompt}, {style_suffix}"
        with torch.no_grad():
            result = self.pipe(
                enhanced_prompt,
                num_inference_steps=30,
                guidance_scale=7.5,
            )
        return np.array(result.images[0])

    def add_cartoon_effects(self, frame):
        """Add cartoon-style effects to a frame.

        Builds an edge mask (median blur + adaptive threshold on the grayscale
        image) and uses it to mask a bilateral-filtered copy of the frame.

        Args:
            frame: Image array; 2-D grayscale and 4-channel input are first
                converted to 3-channel RGB.

        Returns:
            The masked, smoothed 3-channel image.
        """
        # Convert to RGB if necessary
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.ndim == 3 and frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

        # Apply cartoon effect
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        gray = cv2.medianBlur(gray, 5)
        edges = cv2.adaptiveThreshold(
            gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 9
        )
        color = cv2.bilateralFilter(frame, 9, 300, 300)

        # Combine edges with color
        return cv2.bitwise_and(color, color, mask=edges)

    def generate_video_sequence(self, script, style="cartoon", num_frames=24):
        """Generate a sequence of frames based on the script.

        The script is split into scenes on blank lines. Each non-blank scene
        gets one generated image, which is filtered once and then repeated.

        Args:
            script: Text whose blank-line-separated chunks are the scenes.
            style: Style key forwarded to ``generate_cartoon_frame``.
            num_frames: Target total frame count; every scene still gets at
                least 4 frames, so the result can be longer.

        Returns:
            A list of frame arrays; empty when the script has no non-blank
            scene.
        """
        scenes = self._split_scenes(script)
        if not scenes:
            return []

        # Only real scenes count, so blank chunks cannot shrink the
        # per-scene frame budget.
        frames_per_scene = max(num_frames // len(scenes), 4)

        frames = []
        for scene in scenes:
            # Generate base frame for the scene
            base_frame = self.generate_cartoon_frame(
                f"cartoon scene showing: {scene}", style
            )
            # Every repeat of a scene is the same picture, so the (slow)
            # filter runs once per scene instead of once per frame.
            still = self.add_cartoon_effects(base_frame)
            frames.extend(still.copy() for _ in range(frames_per_scene))
        return frames

    # ------------------------------------------------------------------
    # Audio
    # ------------------------------------------------------------------
    def enhance_audio(self, audio_path, style="cartoon"):
        """Add effects to the audio based on style.

        ``"cartoon"`` speeds the clip up by 10 %; ``"kids"`` mixes in a
        quieter copy delayed by 0.1 s as an echo; any other style re-encodes
        the clip unchanged.

        Args:
            audio_path: Path of the source audio file.
            style: Effect selector, see above.

        Returns:
            Path (``str``) of the ``*_enhanced`` file, or the original path
            as a string when processing fails (best effort; the failure is
            logged).
        """
        from moviepy.editor import CompositeAudioClip, vfx

        source_path = Path(audio_path)
        enhanced_path = source_path.with_name(
            f"{source_path.stem}_enhanced{source_path.suffix}"
        )

        clip = None
        try:
            clip = AudioFileClip(str(source_path))
            processed = clip
            if style == "cartoon":
                # Speed up slightly for cartoon effect. moviepy.editor only
                # attaches `speedx` to video clips, so apply it through fx().
                processed = clip.fx(vfx.speedx, 1.1)
            elif style == "kids":
                # Add echo effect for kids music. Audio clips must be mixed
                # with CompositeAudioClip, not CompositeVideoClip.
                echo = clip.set_start(0.1).volumex(0.3)
                processed = CompositeAudioClip([clip, echo])

            processed.write_audiofile(str(enhanced_path), fps=clip.fps)
            return str(enhanced_path)
        except Exception as e:
            logger.exception("Error enhancing audio: %s", e)
            return str(audio_path)
        finally:
            # Release the file reader held by the source clip.
            if clip is not None:
                clip.close()

    # ------------------------------------------------------------------
    # Shared pipeline steps
    # ------------------------------------------------------------------
    def _chat(self, system_prompt, user_prompt):
        """Send one system+user exchange to Groq and return the reply text."""
        completion = self.groq_client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            model=self.GROQ_MODEL,
            temperature=0.7,
        )
        return completion.choices[0].message.content

    def _render_media(self, text, key, *, style, num_frames, fps, audio_prefix):
        """Render ``text`` to a video file and an audio file; return both paths."""
        frames = self.generate_video_sequence(text, style=style, num_frames=num_frames)
        # Validate before the slow TTS step so an unusable script fails fast.
        video_array = self._stack_frames(frames)

        # Generate and enhance audio
        speech_path = self.audio_dir / f"{audio_prefix}_{self._content_id(text)}.wav"
        self.tts.tts_to_file(
            text=text, speaker=self.tts_speaker, file_path=str(speech_path)
        )
        enhanced_speech = self.enhance_audio(speech_path, style)

        # Create video with enhanced frames. write_video takes frames in
        # (T, H, W, C) uint8 layout, so the stacked array is passed as-is
        # and must not be permuted to channels-first.
        video_path = self.video_dir / f"video_{self._content_id(key)}.mp4"
        write_video(str(video_path), torch.from_numpy(video_array), fps=fps)
        return str(video_path), enhanced_speech

    # ------------------------------------------------------------------
    # Public entry points (Gradio handlers)
    # ------------------------------------------------------------------
    def generate_comedy_animation(self, prompt):
        """Generate enhanced comedy animation.

        Args:
            prompt: Topic of the cartoon.

        Returns:
            ``(script, video_path, audio_path)``. On a blank prompt or any
            failure the first item is a message and the two paths are
            ``None``; failures are logged and not raised.
        """
        if not prompt or not prompt.strip():
            return "Please enter a prompt first", None, None
        try:
            # Generate a more structured comedy script
            script_prompt = f"""Write a funny cartoon script about {prompt}.
Include:
- Two distinct character voices
- Physical comedy moments
- Sound effects in [brackets]
- Scene descriptions in (parentheses)
Keep it family-friendly and around 3-4 scenes."""
            script = self._chat(
                "You are a professional cartoon comedy writer.", script_prompt
            )
            video_path, enhanced_speech = self._render_media(
                script,
                prompt,
                style="cartoon",
                num_frames=24,
                fps=12,  # Higher FPS for smoother animation
                audio_prefix="speech",
            )
            return script, video_path, enhanced_speech
        except Exception as e:
            logger.exception("Error in comedy animation generation: %s", e)
            return "Error generating content", None, None

    def generate_kids_music_animation(self, theme):
        """Generate enhanced kids music animation.

        Args:
            theme: Subject the song should teach about.

        Returns:
            ``(lyrics, video_path, audio_path)``. On a blank theme or any
            failure the first item is a message and the two paths are
            ``None``; failures are logged and not raised.
        """
        if not theme or not theme.strip():
            return "Please enter a prompt first", None, None
        try:
            # Generate kid-friendly lyrics with music directions
            lyrics_prompt = f"""Write lyrics for a children's educational song about {theme}.
Include:
- Simple, repetitive chorus
- Educational facts
- [Music notes] for melody changes
- (Action descriptions) for animation
Make it upbeat and memorable!"""
            lyrics = self._chat("You are a children's music composer.", lyrics_prompt)
            video_path, enhanced_speech = self._render_media(
                lyrics,
                theme,
                style="kids",
                num_frames=36,
                fps=15,  # Smooth animation for kids
                audio_prefix="music",
            )
            return lyrics, video_path, enhanced_speech
        except Exception as e:
            logger.exception("Error in kids music animation generation: %s", e)
            return "Error generating content", None, None
# Gradio Interface
def create_interface():
    """Build the Gradio Blocks app with its two generator tabs.

    A single ``EnhancedContentGenerator`` is created up front and its two
    generate methods are bound to the tab buttons. Each handler takes the
    tab's textbox value and fills that tab's text, video and audio outputs.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    content_generator = EnhancedContentGenerator()

    with gr.Blocks(theme='ysharma/steampunk') as ui:
        gr.Markdown("# AI Cartoon Generator")
        gr.Markdown("Generate cartoon comedy clips and kids music videos!")

        with gr.Tab("Cartoon Comedy"):
            topic_box = gr.Textbox(
                label="What should the cartoon be about?",
                placeholder="E.g., 'a penguin learning to fly'",
            )
            comedy_button = gr.Button("Generate Cartoon Comedy", variant="primary")
            script_box = gr.Textbox(label="Generated Script")
            comedy_video = gr.Video(label="Cartoon Animation")
            comedy_sound = gr.Audio(label="Cartoon Audio")

        with gr.Tab("Kids Music Video"):
            theme_box = gr.Textbox(
                label="What should the song teach about?",
                placeholder="E.g., 'the water cycle'",
            )
            music_button = gr.Button("Generate Music Video", variant="primary")
            lyrics_box = gr.Textbox(label="Song Lyrics")
            music_video = gr.Video(label="Music Video")
            music_sound = gr.Audio(label="Song Audio")

        # Event handlers
        comedy_button.click(
            fn=content_generator.generate_comedy_animation,
            inputs=topic_box,
            outputs=[script_box, comedy_video, comedy_sound],
        )
        music_button.click(
            fn=content_generator.generate_kids_music_animation,
            inputs=theme_box,
            outputs=[lyrics_box, music_video, music_sound],
        )

    return ui
# Script entry point: when this file is run directly, build the Gradio
# interface and launch it.
if __name__ == "__main__":
    app = create_interface()
    app.launch()