# Source: SoraWithAzure/src/streamlit_app.py
# Committed by levalencia in 6d01f39 ("first commit"); file size 3.35 kB.
"""
Streamlit app for generating videos using Azure Sora API.

- Users provide an API key, endpoint, and a text prompt. The key and endpoint
  inputs are pre-filled from the AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT
  environment variables when those are set.
- Advanced settings allow selection of video length, resolution (the offered
  resolutions depend on the chosen length), and number of variants.
- Generated videos are saved through VideoStorage, and every stored video is
  listed on the page for all users, not only those from the current session.
"""
import streamlit as st
import os
from sora_video_downloader import SoraClient, VideoJob, VideoStorage
# --- Connection defaults: pre-fill the sidebar inputs; the user can override them ---
# Read from the environment when available. The API key falls back to an empty
# string rather than a placeholder: a placeholder is truthy, so it would slip
# past the required-field check and be sent to Azure as if it were a real key
# (and the password-masked input would hide that it was never filled in).
DEFAULT_API_KEY = os.getenv('AZURE_OPENAI_API_KEY', '')
# NOTE(review): the fallback endpoint points at one specific Azure resource;
# consider requiring AZURE_OPENAI_ENDPOINT instead of shipping a fixed host.
DEFAULT_ENDPOINT = os.getenv('AZURE_OPENAI_ENDPOINT', 'https://levm3-me7f7pgq-eastus2.cognitiveservices.azure.com')
# --- UI: Title and Sidebar ---
st.title("Sora Video Generator (Azure)")

# Connection settings live in the sidebar; both inputs start from
# DEFAULT_API_KEY / DEFAULT_ENDPOINT and may be edited by the user.
with st.sidebar:
    st.header("Azure Sora Settings")
    api_key = st.text_input("API Key", value=DEFAULT_API_KEY, type="password")
    endpoint = st.text_input("Azure AI Foundry Endpoint", value=DEFAULT_ENDPOINT)

# --- UI: Main Input ---
st.header("Generate a Video with Sora")
prompt = st.text_area(
    "Video Prompt",
    "A video of a cat playing with a ball of yarn in a sunny room",
)

# --- UI: Advanced Settings ---
st.subheader("Advanced Settings")
# Two side-by-side columns: length/resolution on the left, variants on the right.
col1, col2 = st.columns(2)
# Every (width, height) pair the UI can offer, ordered so that each longer
# clip length allows a prefix of this list (longer clips drop the largest
# sizes). NOTE(review): presumably mirrors the service's limits — confirm.
_ALL_RESOLUTIONS = [
    (480, 480),
    (854, 480),
    (720, 720),
    (1280, 720),
    (1080, 1080),
    (1920, 1080),
]

# Clip length in seconds -> list of (width, height) pairs selectable for it.
# Slicing gives every duration its own independent list.
DURATION_RES_MAP = {
    5: _ALL_RESOLUTIONS[:6],
    10: _ALL_RESOLUTIONS[:5],
    20: _ALL_RESOLUTIONS[:4],
}
with col1:
    # Length is chosen first because it decides which resolutions are offered.
    n_seconds = st.selectbox("Video Length (seconds)", options=list(DURATION_RES_MAP), index=0)
    valid_resolutions = DURATION_RES_MAP[n_seconds]
    res_labels = ["{}x{}".format(*size) for size in valid_resolutions]
    # The widget's options are list positions; format_func shows each as "WxH".
    res_idx = st.selectbox(
        "Resolution",
        options=list(range(len(res_labels))),
        format_func=lambda pos: res_labels[pos],
        index=0,
    )
    width, height = valid_resolutions[res_idx]

with col2:
    n_variants = st.slider("Number of Variants", min_value=1, max_value=4, value=1)
# --- Video Generation Logic ---
generate = st.button("Generate Video")
status_placeholder = st.empty()
video_storage = VideoStorage()

if generate:
    # Strip surrounding whitespace so a key/endpoint that is only whitespace is
    # rejected, and a pasted value with a trailing space/newline is not sent.
    api_key_value = api_key.strip()
    endpoint_value = endpoint.strip()
    if not api_key_value or not endpoint_value or not prompt.strip():
        st.error("Please provide all required fields.")
    else:
        status_placeholder.info("Starting video generation...")
        try:
            sora = SoraClient(api_key_value, endpoint_value)
            job = VideoJob(
                sora,
                prompt,
                height=height,
                width=width,
                n_seconds=n_seconds,
                n_variants=n_variants,
            )
            succeeded = job.run()
            # Only download when the job reports success, as before.
            saved_files = job.download_videos(video_storage.storage_dir) if succeeded else None
        except Exception as exc:
            # UI boundary: an uncaught exception would abort this script run and
            # the video gallery below would not render. Report it and continue.
            status_placeholder.error(f"Video generation failed: {exc}")
            st.exception(exc)
        else:
            if not succeeded:
                status_placeholder.error("Video generation failed.")
            elif saved_files:
                names = ", ".join(os.path.basename(f) for f in saved_files)
                status_placeholder.success(f"Video(s) generated and saved: {names}")
            else:
                status_placeholder.error("Video generation succeeded but download failed.")
# --- Display All Generated Videos ---
st.header("All Generated Videos")
video_files = video_storage.list_videos()
if not video_files:
    st.info("No videos generated yet.")
else:
    # Reverse-sorted by path. NOTE(review): this is newest-first only if
    # VideoStorage's file names sort chronologically — confirm.
    for video_path in sorted(video_files, reverse=True):
        file_name = os.path.basename(video_path)
        st.video(video_path)
        st.caption(file_name)
        # Read through a context manager so the handle is closed immediately on
        # every rerun instead of being left for the garbage collector.
        with open(video_path, "rb") as video_file:
            video_bytes = video_file.read()
        st.download_button(
            "Download",
            data=video_bytes,
            file_name=file_name,
            mime="video/mp4",
            # Explicit per-file key so buttons sharing the "Download" label
            # never collide on Streamlit's widget IDs.
            key=f"download_{video_path}",
        )