# Spaces: Sleeping (status banner captured from the hosting page; not program code)
import json
import logging
import os
import shutil
import time
from typing import Optional, Dict, Any, List
from urllib.parse import urlparse

import requests
# --- Logging Setup ---
def get_logger(name: str = __name__) -> logging.Logger:
    """Return the logger called *name*, configuring root logging on first use.

    The first call (when neither the named logger nor the root logger has any
    handlers) installs a file handler writing to ``sora_video_downloader.log``
    and a stream handler on the root logger at INFO level.  Later calls only
    look the logger up.

    Args:
        name: Logger name; defaults to this module's ``__name__``.

    Returns:
        The ``logging.Logger`` registered under *name*.
    """
    logger = logging.getLogger(name)
    # basicConfig() attaches its handlers to the ROOT logger, never to the
    # named one, so the root must be inspected as well.  Checking only
    # ``logger.handlers`` would build (and open) a fresh FileHandler on every
    # call, which basicConfig() then silently ignores -- a leaked file handle.
    if not logger.handlers and not logging.getLogger().handlers:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('sora_video_downloader.log'),
                logging.StreamHandler(),
            ],
        )
    return logger
# Module-wide logger shared by every class and function in this file.
logger = get_logger(__name__)
# --- Sora API Client ---
class SoraClient:
    """Small HTTP client for the Sora video-generation job endpoints.

    Every request targets ``{base_url}/openai/v1/video/generations...`` with an
    ``api-version`` query parameter and authenticates with an ``api-key``
    header.  The HTTP methods catch exceptions, log them, and return a
    sentinel (``None``, ``{}``, ``False`` or an ``{"status": "error"}`` dict)
    rather than raising.
    """

    # Lower-cased job states that end polling in wait_for_job().
    _SUCCESS_STATES = frozenset({'completed', 'succeeded', 'success'})
    _FAILURE_STATES = frozenset({'failed', 'error', 'cancelled', 'canceled'})

    def __init__(self, api_key: str, base_url: str, api_version: str = "preview", timeout: float = 30.0):
        """Create a client.

        Args:
            api_key: Value sent in the ``api-key`` header.
            base_url: Service endpoint; a trailing ``/`` is stripped.
            api_version: Value of the ``api-version`` query parameter.
            timeout: Seconds passed as ``timeout`` to every HTTP call so a
                stalled connection cannot block forever.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.api_version = api_version
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({
            'api-key': self.api_key,
            'Content-Type': 'application/json',
        })
        logger.info("SoraClient initialized")

    def _url(self, path: str = "") -> str:
        """Return the full URL for *path* below the video-generations root."""
        return f"{self.base_url}/openai/v1/video/generations{path}?api-version={self.api_version}"

    def start_video_job(self, prompt: str, height: int = 1080, width: int = 1080, n_seconds: int = 5, n_variants: int = 1) -> Optional[str]:
        """Submit a generation job.

        Returns:
            The job id taken from the response (``id``, ``job_id`` or
            ``jobId``), or ``None`` if the request failed or no id was present.
        """
        payload = {
            "model": "sora",
            "prompt": prompt,
            "height": str(height),
            "width": str(width),
            "n_seconds": str(n_seconds),
            "n_variants": str(n_variants),
        }
        try:
            response = self.session.post(self._url("/jobs"), json=payload, timeout=self.timeout)
            if response.status_code not in (200, 201, 202):
                logger.error(f"Failed to start job: {response.status_code} {response.text}")
                return None
            result = response.json()
            job_id = result.get('id') or result.get('job_id') or result.get('jobId')
            if not job_id:
                # A 2xx answer without an id would otherwise fail silently.
                logger.error(f"Job accepted but response contained no job id: {response.text}")
            return job_id
        except Exception as e:  # deliberate boundary: callers expect None, not a raise
            logger.error(f"Exception in start_video_job: {e}")
            return None

    def get_job_status(self, job_id: str) -> Dict[str, Any]:
        """Fetch the job document for *job_id*.

        Returns:
            The decoded JSON body on HTTP 200, otherwise
            ``{"status": "error", "error": <description>}``.
        """
        try:
            response = self.session.get(self._url(f"/jobs/{job_id}"), timeout=self.timeout)
            if response.status_code == 200:
                return response.json()
            logger.error(f"Failed to get job status: {response.status_code} {response.text}")
            return {"status": "error", "error": f"HTTP {response.status_code}"}
        except Exception as e:  # deliberate boundary: report as an error status
            logger.error(f"Exception in get_job_status: {e}")
            return {"status": "error", "error": str(e)}

    def wait_for_job(self, job_id: str, max_wait_time: int = 300, poll_interval: int = 10) -> Optional[Dict[str, Any]]:
        """Poll the job until it reaches a terminal state or the deadline passes.

        Args:
            job_id: Job to poll.
            max_wait_time: Maximum number of seconds to keep polling.
            poll_interval: Seconds to sleep between polls.

        Returns:
            The final job document when the job succeeded; ``None`` when it
            failed, was cancelled, reported an error, or timed out.
        """
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + max_wait_time
        while time.monotonic() < deadline:
            status_result = self.get_job_status(job_id)
            # ``status`` may be missing or JSON null; normalise before lower().
            status = str(status_result.get('status') or '').lower()
            if status in self._SUCCESS_STATES:
                return status_result
            if status in self._FAILURE_STATES:
                logger.error(
                    f"Job {job_id} ended with status '{status}': "
                    f"{status_result.get('error') or 'no details'}"
                )
                return None
            time.sleep(poll_interval)
        logger.error(f"Job {job_id} timed out after {max_wait_time}s")
        return None

    def get_generation_details(self, generation_id: str) -> Dict[str, Any]:
        """Fetch the generation document for *generation_id*.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty dict.
        """
        try:
            resp = self.session.get(self._url(f"/{generation_id}"), timeout=self.timeout)
            if resp.status_code == 200:
                return resp.json()
            logger.error(f"Failed to fetch generation details: HTTP {resp.status_code}")
            return {}
        except Exception as e:  # deliberate boundary: callers expect {} on failure
            logger.error(f"Exception in get_generation_details: {e}")
            return {}

    def extract_video_urls(self, job_result: Dict[str, Any]) -> List[str]:
        """Build the video-content URL for every generation in *job_result*.

        Entries of ``job_result["generations"]`` that are not dicts or have no
        truthy ``id`` are skipped.  Returns an empty list when the key is
        missing or is not a list.
        """
        generations = job_result.get('generations', [])
        if not isinstance(generations, list):
            return []
        return [
            self._url(f"/{g['id']}/content/video")
            for g in generations
            if isinstance(g, dict) and g.get('id')
        ]

    def download_video(self, video_url: str, output_filename: str) -> bool:
        """Stream *video_url* to *output_filename*.

        The body is first written to ``<output_filename>.part`` and moved into
        place only after the whole stream was received, so a failed download
        never leaves a truncated file at the final path.

        Returns:
            ``True`` on success, ``False`` on a non-200 response or any error.
        """
        partial_path = f"{output_filename}.part"
        try:
            # Only the api-key header is sent (no JSON Content-Type); the
            # ``with`` block guarantees the streamed connection is released.
            with requests.get(
                video_url,
                headers={'api-key': self.api_key},
                stream=True,
                timeout=self.timeout,
            ) as response:
                if response.status_code != 200:
                    logger.error(f"Failed to download video: {response.status_code} {response.text}")
                    return False
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(partial_path, output_filename)
            return True
        except Exception as e:  # deliberate boundary: callers expect a bool
            logger.error(f"Exception in download_video: {e}")
            try:
                os.remove(partial_path)
            except OSError:
                pass  # nothing was written, or it is already gone
            return False
# --- Video Job Abstraction ---
class VideoJob:
    """One video-generation request and the artefacts it produces.

    Holds the request parameters, the job id returned by the client, the final
    job document and the list of downloadable video URLs.
    """

    def __init__(self, sora_client: "SoraClient", prompt: str, height: int = 1080, width: int = 1080, n_seconds: int = 5, n_variants: int = 1):
        """Store the client and generation parameters; nothing is sent yet."""
        self.sora_client = sora_client
        self.prompt = prompt
        self.height = height
        self.width = width
        self.n_seconds = n_seconds
        self.n_variants = n_variants
        self.job_id: Optional[str] = None
        self.result: Optional[Dict[str, Any]] = None
        self.video_urls: List[str] = []

    def run(self, wait: bool = True, max_wait_time: int = 300, poll_interval: int = 10) -> bool:
        """Start the job and, optionally, wait for it to finish.

        Args:
            wait: When true, poll until the job finishes and collect its
                video URLs; when false, return right after submission.
            max_wait_time: Seconds to wait for completion (forwarded to the
                client's ``wait_for_job``).
            poll_interval: Seconds between polls (forwarded likewise).

        Returns:
            ``True`` if the job was started and, when *wait* is true, finished
            with at least one video URL; ``False`` otherwise.
        """
        # Clear leftovers so a failed re-run cannot expose a previous run's data.
        self.result = None
        self.video_urls = []

        self.job_id = self.sora_client.start_video_job(
            self.prompt, self.height, self.width, self.n_seconds, self.n_variants
        )
        if not self.job_id:
            logger.error("Failed to start video job")
            return False
        if not wait:
            return True

        self.result = self.sora_client.wait_for_job(
            self.job_id, max_wait_time=max_wait_time, poll_interval=poll_interval
        )
        if not self.result:
            logger.error("Job failed or timed out")
            return False
        self.video_urls = self.sora_client.extract_video_urls(self.result)
        if not self.video_urls:
            logger.error("No video URLs found")
            return False
        return True

    def download_videos(self, output_dir: str) -> List[str]:
        """Download every collected video URL into *output_dir*.

        The directory is created if it does not exist.  Files are named
        ``sora_video_<job_id>_<index>.mp4`` with a 1-based index.

        Returns:
            Paths of the files that were downloaded successfully.
        """
        if output_dir:  # "" means the current directory; makedirs("") would raise
            try:
                os.makedirs(output_dir, exist_ok=True)
            except OSError as e:
                logger.error(f"Cannot create output directory {output_dir}: {e}")
                return []

        saved_files = []
        for i, url in enumerate(self.video_urls, start=1):
            filepath = os.path.join(output_dir, f"sora_video_{self.job_id}_{i}.mp4")
            if self.sora_client.download_video(url, filepath):
                saved_files.append(filepath)
        return saved_files
# --- Video Storage Handler ---
class VideoStorage:
    """Directory-backed store for downloaded ``.mp4`` files."""

    def __init__(self, storage_dir: str = "videos"):
        """Remember *storage_dir* and create it if it does not exist."""
        self.storage_dir = storage_dir
        os.makedirs(self.storage_dir, exist_ok=True)

    def save_video(self, src_path: str, filename: str) -> str:
        """Move the file at *src_path* into the store under *filename*.

        Returns:
            The destination path inside the storage directory.
        """
        dst_path = os.path.join(self.storage_dir, filename)
        # shutil.move falls back to copy-and-delete when source and
        # destination are on different filesystems, where os.rename raises.
        shutil.move(src_path, dst_path)
        return dst_path

    def list_videos(self) -> List[str]:
        """Return the paths of all ``.mp4`` entries in the store, sorted by name."""
        # os.listdir order is arbitrary; sort for a deterministic result.
        return [
            os.path.join(self.storage_dir, f)
            for f in sorted(os.listdir(self.storage_dir))
            if f.endswith('.mp4')
        ]
# --- Main for CLI usage (optional) ---
def main():
    """Generate one sample video and download it into the default storage directory.

    The API key is read from ``AZURE_OPENAI_API_KEY`` (falling back to
    ``AZURE_API_KEY``) and the endpoint from ``AZURE_OPENAI_ENDPOINT``
    (falling back to a built-in default endpoint).  Progress and failures are
    reported on standard output.
    """
    key = os.getenv('AZURE_OPENAI_API_KEY') or os.getenv('AZURE_API_KEY')
    endpoint = os.getenv('AZURE_OPENAI_ENDPOINT') or "https://levm3-me7f7pgq-eastus2.cognitiveservices.azure.com"
    if not (key and endpoint):
        print("Please set your Azure API key and endpoint.")
        return

    video_job = VideoJob(
        SoraClient(key, endpoint),
        "A video of a cat playing with a ball of yarn in a sunny room",
    )
    if not video_job.run():
        print("Video generation failed.")
        return

    # The storage object is created only after a successful run.
    downloaded = video_job.download_videos(VideoStorage().storage_dir)
    print(f"Downloaded: {downloaded}")


if __name__ == "__main__":
    main()