|
|
|
import streamlit as st |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import tempfile |
|
import os |
|
|
|
|
|
# Page-level settings for the app: browser-tab title, tab icon and wide layout.
# The icon literal was a mis-decoded byte sequence ("π₯"); it is restored to the
# movie-camera emoji those bytes encode so it is a valid single-emoji icon.
st.set_page_config(
    page_title="Video Generator",
    page_icon="🎥",
    layout="wide",
)
|
|
|
@st.cache_resource
def load_models():
    """Load the text-to-video pipeline and the BLIP captioning model.

    The function is wrapped in ``st.cache_resource`` so the loaded objects are
    reused instead of being reloaded on every Streamlit rerun.

    Returns:
        A ``(pipeline, blip, blip_processor)`` tuple. ``pipeline`` and ``blip``
        are placed on the GPU when CUDA is available, otherwise on the CPU.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision is only requested on the GPU: a float16 pipeline moved to
    # the CPU is unsupported/very slow, so fall back to float32 there.
    dtype = torch.float16 if use_cuda else torch.float32

    pipeline = DiffusionPipeline.from_pretrained(
        "cerspense/zeroscope_v2_576w",
        torch_dtype=dtype,
    )
    pipeline.to(device)

    blip = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    blip_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    blip.to(device)

    return pipeline, blip, blip_processor
|
|
|
def enhance_image(image, alpha=1.2, beta=10):
    """Return a brighter, higher-contrast copy of ``image``.

    Every channel value ``v`` is mapped through ``cv2.convertScaleAbs``, i.e.
    ``|alpha * v + beta|`` saturated to the 8-bit range.

    Args:
        image: Source picture (a PIL image, or anything ``np.array`` accepts).
        alpha: Contrast gain applied to each value. Defaults to 1.2.
        beta: Brightness offset added to each value. Defaults to 10.

    Returns:
        A new ``PIL.Image.Image`` built from the adjusted pixel array.
    """
    # Palette, alpha and other non-RGB modes must not be scaled as-is:
    # palette images expose colour *indices*, and RGBA would have its alpha
    # channel brightened too. Normalise those to plain RGB first; greyscale
    # and RGB inputs are processed exactly as before.
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in ("L", "RGB"):
        image = image.convert("RGB")

    pixels = np.array(image)
    adjusted = cv2.convertScaleAbs(pixels, alpha=alpha, beta=beta)
    return Image.fromarray(adjusted)
|
|
|
def get_description(image, blip, blip_processor, max_length=50):
    """Generate a text caption for ``image`` with the BLIP model.

    Args:
        image: Picture to describe; passed unchanged to ``blip_processor``.
        blip: Captioning model exposing ``generate`` and a ``device`` attribute.
        blip_processor: Callable that turns the image into model inputs
            (must yield a ``"pixel_values"`` tensor) and can ``decode`` the
            generated token ids.
        max_length: Upper bound forwarded to ``blip.generate``. Defaults to 50.

    Returns:
        The decoded caption with special tokens removed.
    """
    inputs = blip_processor(image, return_tensors="pt")

    # Follow the model's own device rather than re-checking CUDA availability,
    # so the inputs always land wherever the model actually lives.
    pixel_values = inputs["pixel_values"].to(blip.device)

    # Inference only: no gradients are needed for caption generation.
    with torch.no_grad():
        generated_ids = blip.generate(pixel_values=pixel_values, max_length=max_length)

    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
def _frames_as_uint8_rgb(frames):
    """Normalise pipeline output to a ``(num_frames, H, W, C)`` uint8 array."""
    arr = np.asarray(frames)

    # Some diffusers releases return a batched result (batch, frames, H, W, C);
    # a single prompt was submitted, so keep the first (only) video.
    if arr.ndim == 5:
        arr = arr[0]
    if arr.ndim != 4 or arr.shape[0] == 0:
        raise ValueError(
            f"Unexpected frame data from the pipeline (shape {arr.shape})."
        )

    if arr.dtype != np.uint8:
        if np.issubdtype(arr.dtype, np.floating) and arr.max() <= 1.0:
            # Float output is presumably normalised to [0, 1] - scale to 8-bit.
            arr = arr * 255.0
        arr = np.rint(np.clip(arr, 0, 255)).astype(np.uint8)

    return arr


def generate_video(pipeline, description, num_inference_steps=30, num_frames=16, fps=8):
    """Render a short MP4 clip from a text prompt.

    Args:
        pipeline: Text-to-video pipeline; called with the prompt and expected
            to return an object with a ``frames`` attribute.
        description: Prompt describing the desired video.
        num_inference_steps: Denoising steps passed to the pipeline.
        num_frames: Number of frames requested from the pipeline.
        fps: Frame rate of the written video file.

    Returns:
        Path of the written ``output.mp4`` inside a freshly created temporary
        directory. The directory is not removed here; the caller owns it.

    Raises:
        ValueError: If the pipeline returns no usable frames.
        RuntimeError: If OpenCV cannot open the output file for writing.
    """
    raw_frames = pipeline(
        description,
        num_inference_steps=num_inference_steps,
        num_frames=num_frames,
    ).frames
    frames = _frames_as_uint8_rgb(raw_frames)

    temp_dir = tempfile.mkdtemp()
    temp_path = os.path.join(temp_dir, "output.mp4")

    height, width = frames.shape[1:3]
    # NOTE(review): 'mp4v' output may not play in browser-based players;
    # confirm playback and consider an H.264 encoder if it fails.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video_writer = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
    if not video_writer.isOpened():
        video_writer.release()
        raise RuntimeError(f"Could not open video writer for {temp_path}")

    try:
        for frame in frames:
            # The pipeline yields RGB; OpenCV writes BGR.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the file, even if a frame fails to encode.
        video_writer.release()

    return temp_path
|
|
|
def main():
    """Render the Streamlit page: upload, caption, edit prompt, generate video.

    Flow: load the models, let the user upload an image, show it next to an
    enhanced copy, caption the enhanced copy, let the user edit the caption,
    and on request generate, display and offer the resulting video for
    download. Any exception raised along the way is reported in the page
    instead of propagating.
    """
    st.title("🎥 AI Video Generator")
    st.write("Upload an image to generate a video based on its content!")

    try:
        pipeline, blip, blip_processor = load_models()

        image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])
        if not image_file:
            return

        col1, col2 = st.columns(2)

        with col1:
            image = Image.open(image_file)
            st.image(image, caption="Original Image")

        with col2:
            enhanced_image = enhance_image(image)
            st.image(enhanced_image, caption="Enhanced Image")

        description = get_description(enhanced_image, blip, blip_processor)
        # NOTE(review): the leading glyph below looks like a mis-decoded emoji
        # whose original cannot be recovered unambiguously; confirm the
        # intended icon before changing it.
        st.write("π Generated Description:", description)

        modified_description = st.text_area("Edit description if needed:", description)

        if st.button("🎬 Generate Video"):
            if not modified_description.strip():
                # An empty prompt gives the pipeline nothing to work from.
                st.warning("Please enter a description before generating a video.")
                return

            with st.spinner("Generating video... This may take a few minutes."):
                video_path = generate_video(pipeline, modified_description)

            st.success("Video generated successfully!")
            st.video(video_path)

            # Read the file up front so no open handle outlives this block.
            with open(video_path, 'rb') as f:
                video_bytes = f.read()

            st.download_button(
                label="Download Video",
                data=video_bytes,
                file_name="generated_video.mp4",
                mime="video/mp4",
            )

    except Exception as e:
        # Top-level UI boundary: surface the failure to the user.
        st.error(f"An error occurred: {e}")
        st.error("Please try again or contact support if the error persists.")
|
|
|
# Script entry point: build the page only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()