# app.py
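"""
AI Video Generator (Streamlit app).

Flow: the user uploads an image, a lightly enhanced copy is captioned with
BLIP (Salesforce/blip-image-captioning-base), and the caption (editable by
the user) is fed to the cerspense/zeroscope_v2_576w text-to-video pipeline;
the resulting frames are written to an MP4 for preview and download.
"""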
import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration
from diffusers import DiffusionPipeline
import torch
import cv2
import numpy as np
from PIL import Image
import tempfile
import os
# Configure page
st.set_page_config(
    page_title="Video Generator",
    page_icon="🎥",
    layout="wide"
)
@st.cache_resource
def load_models():
    # Use half precision only on GPU: fp16 kernels for several ops are not
    # implemented on CPU, so float32 is the safe fallback there
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load text-to-video model
    pipeline = DiffusionPipeline.from_pretrained(
        "cerspense/zeroscope_v2_576w",
        torch_dtype=dtype
    )
    pipeline.to(device)

    # Load image captioning model
    blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip.to(device)

    return pipeline, blip, blip_processor
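
# Memory note (an assumption, not part of the original app): zeroscope_v2_576w
# is heavy for small GPUs. If you hit CUDA out-of-memory errors, sequential CPU
# offload can be used instead of pipeline.to("cuda"):
#     pipeline.enable_model_cpu_offload()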

def enhance_image(image):
    # Convert PIL Image to numpy array
    img_array = np.array(image)
    # Basic enhancement: increase contrast (alpha) and brightness (beta)
    enhanced = cv2.convertScaleAbs(img_array, alpha=1.2, beta=10)
    return Image.fromarray(enhanced)
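
# Alternative enhancement sketch (an assumption, not in the original app):
# CLAHE on the LAB lightness channel often handles uneven lighting better than
# a global linear contrast/brightness scale:
#     lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
#     lab[..., 0] = cv2.createCLAHE(clipLimit=2.0).apply(lab[..., 0])
#     enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)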

def get_description(image, blip, blip_processor):
    # Process image for BLIP
    inputs = blip_processor(image, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Generate caption
    with torch.no_grad():
        generated_ids = blip.generate(pixel_values=inputs["pixel_values"], max_length=50)
        description = blip_processor.decode(generated_ids[0], skip_special_tokens=True)

    return description
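
# Caption quality note (an assumption, not in the original app): beam search
# often produces more fluent BLIP captions at a small latency cost, e.g.
#     generated_ids = blip.generate(pixel_values=inputs["pixel_values"],
#                                   max_length=50, num_beams=4)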

def generate_video(pipeline, description):
    # Generate video frames
    video_frames = pipeline(
        description,
        num_inference_steps=30,
        num_frames=16
    ).frames

    # The format of .frames has changed across diffusers versions: it may be a
    # list of HxWx3 uint8 arrays, or a (batch, frames, H, W, 3) float array in
    # [0, 1]. Normalize to a (frames, H, W, 3) uint8 array for OpenCV.
    video_frames = np.asarray(video_frames)
    if video_frames.ndim == 5:
        video_frames = video_frames[0]
    if video_frames.dtype != np.uint8:
        video_frames = (video_frames * 255).clip(0, 255).astype(np.uint8)

    # Create temporary directory and file path
    temp_dir = tempfile.mkdtemp()
    temp_path = os.path.join(temp_dir, "output.mp4")

    # Convert frames to video (8 fps, mp4v codec)
    height, width = video_frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(temp_path, fourcc, 8, (width, height))
    for frame in video_frames:
        # OpenCV expects BGR channel order
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video_writer.release()

    return temp_path
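
# Alternative (assumes a recent diffusers version): the library ships a helper
# that writes frames straight to MP4, accepting PIL or numpy frames, which
# could replace the manual OpenCV writer above:
#     from diffusers.utils import export_to_video
#     temp_path = export_to_video(video_frames, temp_path, fps=8)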

def main():
    st.title("🎥 AI Video Generator")
    st.write("Upload an image to generate a video based on its content!")

    try:
        # Load models
        pipeline, blip, blip_processor = load_models()

        # File uploader
        image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])

        if image_file:
            # Display original and enhanced image side by side
            col1, col2 = st.columns(2)
            with col1:
                # Normalize to RGB so RGBA/palette PNGs don't break downstream steps
                image = Image.open(image_file).convert("RGB")
                st.image(image, caption="Original Image")
            with col2:
                enhanced_image = enhance_image(image)
                st.image(enhanced_image, caption="Enhanced Image")

            # Get and display description
            description = get_description(enhanced_image, blip, blip_processor)
            st.write("📝 Generated Description:", description)

            # Allow user to edit description
            modified_description = st.text_area("Edit description if needed:", description)

            # Generate video button
            if st.button("🎬 Generate Video"):
                with st.spinner("Generating video... This may take a few minutes."):
                    video_path = generate_video(pipeline, modified_description)
                    st.success("Video generated successfully!")
                    st.video(video_path)

                    # Add download button
                    with open(video_path, 'rb') as f:
                        st.download_button(
                            label="Download Video",
                            data=f,
                            file_name="generated_video.mp4",
                            mime="video/mp4"
                        )

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.error("Please try again or contact support if the error persists.")


if __name__ == "__main__":
    main()
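
# To run locally (assumed dependency set; the original does not pin versions):
#     pip install streamlit torch diffusers transformers accelerate opencv-python pillow
#     streamlit run app.py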