from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates  # For serving HTML
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
import cv2  # OpenCV for video processing
import uuid  # For generating unique filenames
import os  # For interacting with the file system
import requests  # For making HTTP requests
import random
import string
import json
import shutil  # For file operations
import ast  # For safely evaluating string literals
import asyncio  # For concurrent operations, retries, and delays
import logging  # For structured logging
# --- Application Setup ---
app = FastAPI(title="Advanced NSFW Video Detector API", version="1.1.0")

# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Templates for HTML Homepage ---
# Create a 'templates' directory alongside main.py and put an 'index.html' inside it.
# On Hugging Face Spaces you may need to adjust the path or ensure the templates
# directory is included in the build. Jinja2 is the cleaner option; if it cannot be
# initialized, the homepage falls back to a basic inline HTML string.
try:
    templates_path = os.path.join(os.path.dirname(__file__), "templates")
    if not os.path.exists(templates_path):
        os.makedirs(templates_path)  # Create if it does not exist (for local dev)
    templates = Jinja2Templates(directory=templates_path)
    # Create a placeholder index.html if one does not exist, for local testing
    dummy_html_path = os.path.join(templates_path, "index.html")
    if not os.path.exists(dummy_html_path):
        with open(dummy_html_path, "w") as f:
            f.write("<h1>Dummy Index Page - Replace with actual instructions</h1>")
except Exception as e:
    logger.warning(f"Jinja2Templates initialization failed: {e}. Will use basic HTML string for homepage.")
    templates = None
# --- Configuration (potentially from environment variables or a settings file) ---
DEFAULT_REQUEST_TIMEOUT = 20  # Timeout (seconds) for individual NSFW checker requests
MAX_RETRY_ATTEMPTS = 3
RETRY_BACKOFF_FACTOR = 2  # Base of the exponential backoff delay, in seconds
MAX_CONCURRENT_HTTP_REQUESTS = 10  # Cap on simultaneous outbound HTTP requests
# --- NSFW Checker URLs (ideally these would live in a config file) ---
NSFW_CHECKER_CONFIG = {
    "checker1_yoinked": {
        "queue_join_url": "https://yoinked-da-nsfw-checker.hf.space/queue/join",
        "queue_data_url_template": "https://yoinked-da-nsfw-checker.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}, "chen-convnext", 0.5, True, True],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 12
        }
    },
    "checker2_jamescookjr90": {
        "queue_join_url": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    },
    "checker3_zanderlewis": {
        "predict_url": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict",
        "event_url_template": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict/{event_id}",
        "payload_template": lambda img_url: {'data': [{'path': img_url}]}
    },
    "checker4_error466": {
        "base_url": "https://error466-falconsai-nsfw-image-detection.hf.space",
        "replica_code_needed": True,
        "queue_join_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/join",
        "queue_data_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 58
        }
    },
    "checker5_phelpsgg": {
        "queue_join_url": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    }
}
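
# Illustration only (not executed at runtime): each checker's payload_template is a
# lambda that builds the Gradio queue payload. With a hypothetical image URL and
# session hash, checker1's template produces:
#
#   payload = NSFW_CHECKER_CONFIG["checker1_yoinked"]["payload_template"](
#       "https://example.com/frame.jpg", "abc123def0"
#   )
#   # -> {'data': [{'path': 'https://example.com/frame.jpg'}, "chen-convnext", 0.5, True, True],
#   #     'session_hash': 'abc123def0', 'fn_index': 0, 'trigger_id': 12}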
# --- Task Management for Asynchronous Processing ---
tasks_db: Dict[str, Dict] = {}
# --- Helper Functions ---
# A single module-level semaphore shared by all callers, so at most
# MAX_CONCURRENT_HTTP_REQUESTS outbound requests run at once.
http_semaphore = asyncio.Semaphore(MAX_CONCURRENT_HTTP_REQUESTS)

async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
    """Makes an HTTP request with retries, exponential backoff, and jitter."""
    headers = kwargs.pop("headers", {})
    headers.setdefault("User-Agent", "NSFWDetectorClient/1.1")
    for attempt in range(MAX_RETRY_ATTEMPTS):
        try:
            async with http_semaphore:
                loop = asyncio.get_running_loop()
                # The requests library is synchronous, so run it in a thread pool
                response = await loop.run_in_executor(
                    None,
                    lambda: requests.request(method, url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT, **kwargs)
                )
                response.raise_for_status()
                return response
        except requests.exceptions.Timeout:
            logger.warning(f"Request timeout for {url} on attempt {attempt + 1}")
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code in [429, 502, 503, 504]:
                logger.warning(f"HTTP error {e.response.status_code} for {url} on attempt {attempt + 1}")
            else:
                logger.error(f"Non-retriable HTTP error for {url}: {e}")
                return e.response if e.response is not None else None
        except requests.exceptions.RequestException as e:
            logger.error(f"Request exception for {url} on attempt {attempt + 1}: {e}")
        if attempt < MAX_RETRY_ATTEMPTS - 1:
            delay = (RETRY_BACKOFF_FACTOR ** attempt) + random.uniform(0, 0.5)
            logger.info(f"Retrying {url} in {delay:.2f} seconds...")
            await asyncio.sleep(delay)
    logger.error(f"All {MAX_RETRY_ATTEMPTS} retry attempts failed for {url}.")
    return None
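
# Minimal usage sketch (the URL here is a placeholder, not part of this service):
#
#   resp = await http_request_with_retry("GET", "https://example.com/health")
#   if resp is not None and resp.status_code == 200:
#       data = resp.json()
#
# Callers must handle both a None return (all retries failed) and a non-200
# response (a non-retriable HTTP error passed back to the caller).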
def get_replica_code_sync(url: str) -> Optional[str]:
    try:
        r = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers={"User-Agent": "NSFWDetectorClient/1.1"})
        r.raise_for_status()
        # This parsing is fragile: it scrapes the replica code out of the Space's HTML
        parts = r.text.split('replicas/')
        if len(parts) > 1:
            return parts[1].split('"};')[0]
        logger.warning(f"Could not find 'replicas/' in content from {url}")
        return None
    except (requests.exceptions.RequestException, IndexError, KeyError) as e:
        logger.error(f"Error getting replica code for {url}: {e}")
        return None

async def get_replica_code(url: str) -> Optional[str]:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, get_replica_code_sync, url)
async def parse_hf_queue_response(response_content: str) -> Optional[str]:
    try:
        messages = response_content.strip().split('\n')
        for msg_str in reversed(messages):
            if msg_str.startswith("data:"):
                try:
                    data_json_str = msg_str[len("data:"):].strip()
                    if not data_json_str:
                        continue
                    parsed_json = json.loads(data_json_str)
                    if parsed_json.get("msg") == "process_completed":
                        output_data = parsed_json.get("output", {}).get("data")
                        if output_data and isinstance(output_data, list) and len(output_data) > 0:
                            first_item = output_data[0]
                            if isinstance(first_item, dict):
                                return first_item.get('label')
                            if isinstance(first_item, str):
                                return first_item
                        logger.warning(f"Unexpected 'process_completed' data structure: {output_data}")
                        return None
                except json.JSONDecodeError:
                    logger.debug(f"Failed to decode JSON from part of HF stream: {data_json_str[:100]}")
                    continue
        return None
    except Exception as e:
        logger.error(f"Error parsing HF queue response: {e}, content: {response_content[:200]}")
        return None
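
# The Gradio queue stream is a sequence of SSE-style "data:" lines; the one this
# parser cares about looks roughly like the following (illustrative example, the
# exact field shapes vary by Space):
#
#   data: {"msg": "process_completed", "output": {"data": [{"label": "nsfw"}]}}
#
# from which parse_hf_queue_response returns "nsfw".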
async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
    config = NSFW_CHECKER_CONFIG.get(checker_name)
    if not config:
        logger.error(f"No configuration found for checker: {checker_name}")
        return None
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(10))
    try:
        if "predict_url" in config:  # ZanderLewis-style /call/predict API
            payload = config["payload_template"](img_url)
            response_predict = await http_request_with_retry("POST", config["predict_url"], json=payload)
            if not response_predict or response_predict.status_code != 200:
                logger.error(f"{checker_name} predict call failed or returned non-200. Status: {response_predict.status_code if response_predict else 'N/A'}")
                return None
            json_data = response_predict.json()
            event_id = json_data.get('event_id')
            if not event_id:
                logger.error(f"{checker_name} did not return event_id.")
                return None
            event_url = config["event_url_template"].format(event_id=event_id)
            for _ in range(10):
                await asyncio.sleep(random.uniform(1.5, 2.5))  # Randomized poll delay
                response_event = await http_request_with_retry("GET", event_url)
                if response_event and response_event.status_code == 200:
                    event_stream_content = response_event.text  # Full text of the event stream
                    if 'data:' in event_stream_content:
                        final_data_str = event_stream_content.strip().split('data:')[-1].strip()
                        if final_data_str:
                            try:
                                parsed_list = ast.literal_eval(final_data_str)
                                if isinstance(parsed_list, list) and parsed_list:
                                    return parsed_list[0].get('label')
                                logger.warning(f"{checker_name} parsed empty or invalid list from event stream: {final_data_str[:100]}")
                            except (SyntaxError, ValueError, IndexError, TypeError) as e:
                                logger.warning(f"{checker_name} error parsing event stream: {e}, stream: {final_data_str[:100]}")
                elif response_event:
                    logger.warning(f"{checker_name} polling event_url returned status {response_event.status_code}")
                else:
                    logger.warning(f"{checker_name} polling event_url got no response.")
        else:  # Queue-based Gradio APIs
            data_url_template = config["queue_data_url_template"]
            if config.get("replica_code_needed"):
                replica_base_url = config.get("base_url")
                if not replica_base_url:
                    logger.error(f"{checker_name} needs replica_code but base_url is missing.")
                    return None
                code = await get_replica_code(replica_base_url)
                if not code:
                    logger.error(f"Failed to get replica code for {checker_name}")
                    return None
                join_url = config["queue_join_url_template"].format(code=code)
                data_url = data_url_template.format(code=code, session_hash=session_hash)
            else:
                join_url = config["queue_join_url"]
                data_url = data_url_template.format(session_hash=session_hash)
            payload = config["payload_template"](img_url, session_hash)
            response_join = await http_request_with_retry("POST", join_url, json=payload)
            if not response_join or response_join.status_code != 200:
                logger.error(f"{checker_name} queue/join call failed. Status: {response_join.status_code if response_join else 'N/A'}")
                return None
            for _ in range(15):
                await asyncio.sleep(random.uniform(1.5, 2.5))  # Randomized poll delay
                response_data = await http_request_with_retry("GET", data_url, stream=True)  # stream=True is important here
                if response_data and response_data.status_code == 200:
                    buffer = ""
                    for content_chunk in response_data.iter_content(chunk_size=1024, decode_unicode=True):
                        if content_chunk:
                            buffer += content_chunk
                            if buffer.endswith("\n\n"):  # A blank line terminates a complete SSE message block
                                label = await parse_hf_queue_response(buffer)
                                if label:
                                    return label
                                buffer = ""  # Reset buffer after processing a block
                elif response_data:
                    logger.warning(f"{checker_name} polling queue/data returned status {response_data.status_code}")
                else:
                    logger.warning(f"{checker_name} polling queue/data got no response.")
        logger.warning(f"{checker_name} failed to get a conclusive result for {img_url}")
        return None
    except Exception as e:
        logger.error(f"Exception in {checker_name} for {img_url}: {e}", exc_info=True)
        return None
async def check_nsfw_final_concurrent(img_url: str) -> Optional[bool]:
    logger.info(f"Starting NSFW check for: {img_url}")
    checker_names = [
        "checker2_jamescookjr90", "checker3_zanderlewis", "checker5_phelpsgg",
        "checker4_error466", "checker1_yoinked"
    ]
    # Wrap tasks with their names for better logging upon completion
    named_tasks = {
        name: asyncio.create_task(check_nsfw_single_generic(name, img_url))
        for name in checker_names
    }
    sfw_found_by_any_checker = False  # Tracks whether any checker returned an SFW result
    for task_name in named_tasks:  # Iterate in the defined order of preference
        try:
            label = await named_tasks[task_name]  # Wait for this specific task
            logger.info(f"Checker '{task_name}' result for {img_url}: {label}")
            if label:
                label_lower = label.lower()
                if 'nsfw' in label_lower:
                    logger.info(f"NSFW detected by '{task_name}' for {img_url}. Final: True.")
                    # Optionally cancel the remaining tasks here:
                    # for t_name, t_obj in named_tasks.items():
                    #     if t_name != task_name and not t_obj.done():
                    #         t_obj.cancel()
                    return True
                if 'sfw' in label_lower or 'safe' in label_lower:
                    sfw_found_by_any_checker = True
                    # Do not return False yet; wait for the other checkers.
            # If label is None, or neither nsfw nor sfw, move on to the next checker's result
        except asyncio.CancelledError:
            logger.info(f"Checker '{task_name}' was cancelled for {img_url}.")
        except Exception as e:
            logger.error(f"Error processing result from checker '{task_name}' for {img_url}: {e}")
    if sfw_found_by_any_checker:  # No NSFW detected by any checker, but at least one said SFW
        logger.info(f"SFW confirmed for {img_url} (no NSFW detected, at least one SFW). Final: False.")
        return False
    logger.warning(f"All NSFW checkers inconclusive or failed for {img_url}. Final: None.")
    return None
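
# Usage sketch (illustrative; the URL is a placeholder): the aggregator returns a
# tri-state verdict rather than a plain boolean.
#
#   verdict = await check_nsfw_final_concurrent("https://example.com/frame.jpg")
#   # True  -> at least one checker labelled the frame NSFW
#   # False -> no NSFW label, and at least one checker said SFW/safe
#   # None  -> every checker failed or was inconclusive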
# --- Video Processing Logic ---
BASE_FRAMES_DIR = "/tmp/video_frames_service_advanced_v2"
os.makedirs(BASE_FRAMES_DIR, exist_ok=True)
app.mount("/static_frames", StaticFiles(directory=BASE_FRAMES_DIR), name="static_frames")
def extract_frames_sync(video_path, num_frames_to_extract, request_specific_frames_dir):
    vidcap = cv2.VideoCapture(video_path)
    if not vidcap.isOpened():
        logger.error(f"Cannot open video file: {video_path}")
        return []
    total_frames_in_video = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    extracted_frame_paths = []
    if total_frames_in_video == 0:
        logger.warning(f"Video {video_path} has no frames.")
        vidcap.release()
        return []
    # Do not try to extract more frames than the video contains
    actual_frames_to_extract = min(num_frames_to_extract, total_frames_in_video)
    if actual_frames_to_extract == 0 and total_frames_in_video > 0:  # Edge case: num_frames is 0 but the video has frames
        actual_frames_to_extract = 1  # Extract at least one frame if possible
    if actual_frames_to_extract == 0:
        vidcap.release()
        return []
    for i in range(actual_frames_to_extract):
        # Distribute extraction evenly across the video
        frame_number = int(i * total_frames_in_video / actual_frames_to_extract)
        frame_number = min(frame_number, total_frames_in_video - 1)  # Keep within bounds
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        success, image = vidcap.read()
        if success:
            frame_filename = os.path.join(request_specific_frames_dir, f"frame_{uuid.uuid4().hex}.jpg")
            if cv2.imwrite(frame_filename, image):
                extracted_frame_paths.append(frame_filename)
            else:
                logger.error(f"Failed to write frame: {frame_filename}")
        else:
            logger.warning(f"Failed to read frame at position {frame_number} from {video_path}. Total frames: {total_frames_in_video}")
            # Don't break immediately; try the next calculated frame unless the issue persists
    vidcap.release()
    return extracted_frame_paths
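
# Example call (paths are hypothetical): extracting 5 evenly spaced frames from a
# 100-frame video reads positions 0, 20, 40, 60 and 80:
#
#   paths = extract_frames_sync("/tmp/video.mp4", 5, "/tmp/frames/task-id")
#   # -> ["/tmp/frames/task-id/frame_<uuid>.jpg", ...] (up to 5 paths)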
async def process_video_core(task_id: str, video_path_on_disk: str, num_frames_to_analyze: int, app_base_url: str):
    tasks_db[task_id].update({"status": "processing", "message": "Extracting frames..."})
    request_frames_subdir = os.path.join(BASE_FRAMES_DIR, task_id)
    os.makedirs(request_frames_subdir, exist_ok=True)
    extracted_frames_disk_paths = []
    try:
        loop = asyncio.get_running_loop()
        extracted_frames_disk_paths = await loop.run_in_executor(
            None, extract_frames_sync, video_path_on_disk, num_frames_to_analyze, request_frames_subdir
        )
        if not extracted_frames_disk_paths:
            tasks_db[task_id].update({"status": "failed", "message": "No frames could be extracted."})
            logger.error(f"Task {task_id}: No frames extracted from {video_path_on_disk}")
            return  # The finally block cleans up the video file
        tasks_db[task_id].update({
            "status": "processing",
            "message": f"Analyzing {len(extracted_frames_disk_paths)} frames..."
        })
        nsfw_count = 0
        frame_results_list = []
        base_url_for_static_frames = f"{app_base_url.rstrip('/')}/static_frames/{task_id}"
        analysis_coroutines = []
        for frame_disk_path in extracted_frames_disk_paths:
            frame_filename_only = os.path.basename(frame_disk_path)
            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
            analysis_coroutines.append(check_nsfw_final_concurrent(img_http_url))
        nsfw_detection_results = await asyncio.gather(*analysis_coroutines, return_exceptions=True)
        for i, detection_result in enumerate(nsfw_detection_results):
            frame_disk_path = extracted_frames_disk_paths[i]
            frame_filename_only = os.path.basename(frame_disk_path)
            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
            is_nsfw_str = "unknown"
            if isinstance(detection_result, Exception):
                logger.error(f"Task {task_id}: Error analyzing frame {img_http_url}: {detection_result}")
                is_nsfw_str = "error"
            elif detection_result is True:  # detection_result is True, False, or None
                nsfw_count += 1
                is_nsfw_str = "true"
            elif detection_result is False:
                is_nsfw_str = "false"
            frame_results_list.append({"frame_url": img_http_url, "nsfw_detected": is_nsfw_str})
        result_summary = {
            "nsfw_count": nsfw_count,
            "total_frames_analyzed": len(extracted_frames_disk_paths),
            "frames": frame_results_list
        }
        tasks_db[task_id].update({"status": "completed", "result": result_summary, "message": "Processing complete."})
        logger.info(f"Task {task_id}: Processing complete. Result: {result_summary}")
    except Exception as e:
        logger.error(f"Task {task_id}: Unhandled exception in process_video_core: {e}", exc_info=True)
        tasks_db[task_id].update({"status": "failed", "message": f"An internal error occurred: {str(e)}"})
    finally:
        if os.path.exists(video_path_on_disk):
            try:
                os.remove(video_path_on_disk)
                logger.info(f"Task {task_id}: Cleaned up video file: {video_path_on_disk}")
            except OSError as e_remove:
                logger.error(f"Task {task_id}: Error cleaning up video file {video_path_on_disk}: {e_remove}")
        # Consider a separate job that cleans up frame directories (request_frames_subdir) after a TTL
# --- API Request/Response Models ---
class VideoProcessRequest(BaseModel):
    video_url: Optional[str] = Field(None, description="Publicly accessible URL of the video to process.")
    num_frames: int = Field(10, gt=0, le=50, description="Number of frames to extract (1-50). Capped at 50 for performance.")
    app_base_url: str = Field(..., description="Public base URL of this API service (e.g., https://your-username-your-space-name.hf.space).")

class TaskCreationResponse(BaseModel):
    task_id: str
    status_url: str
    message: str

class FrameResult(BaseModel):
    frame_url: str
    nsfw_detected: str

class VideoProcessResult(BaseModel):
    nsfw_count: int
    total_frames_analyzed: int
    frames: List[FrameResult]

class TaskStatusResponse(BaseModel):
    task_id: str
    status: str
    message: Optional[str] = None
    result: Optional[VideoProcessResult] = None
# --- API Endpoints ---
@app.post("/process_video_async", response_model=TaskCreationResponse, status_code=202)
async def process_video_from_url_async_endpoint(
    request_data: VideoProcessRequest,  # Named 'request_data' to avoid clashing with FastAPI's Request object
    background_tasks: BackgroundTasks
):
    if not request_data.video_url:
        raise HTTPException(status_code=400, detail="video_url must be provided.")
    task_id = str(uuid.uuid4())
    tasks_db[task_id] = {"status": "pending", "message": "Task received, preparing for download."}
    temp_video_file_path = None
    try:
        # The download itself could also run inside the background task to avoid blocking;
        # it is kept here for simplicity of passing the file path along. A more robust
        # variant: background_tasks.add_task(download_and_then_process, task_id, request_data.video_url, ...)
        # Use a task-specific temporary directory for the downloaded video
        task_download_dir = os.path.join(BASE_FRAMES_DIR, "_video_downloads", task_id)
        os.makedirs(task_download_dir, exist_ok=True)
        # Derive the file suffix from the URL, falling back to .mp4
        video_suffix = os.path.splitext(request_data.video_url.split("?")[0])[-1] or ".mp4"
        if not video_suffix.startswith("."):
            video_suffix = "." + video_suffix
        temp_video_file_path = os.path.join(task_download_dir, f"downloaded_video{video_suffix}")
        logger.info(f"Task {task_id}: Attempting to download video from {request_data.video_url} to {temp_video_file_path}")
        dl_response = await http_request_with_retry("GET", request_data.video_url, stream=True)
        if not dl_response or dl_response.status_code != 200:
            if os.path.exists(task_download_dir):
                shutil.rmtree(task_download_dir)
            tasks_db[task_id].update({"status": "failed", "message": f"Failed to download video. Status: {dl_response.status_code if dl_response else 'N/A'}"})
            raise HTTPException(status_code=400, detail=f"Error downloading video: Status {dl_response.status_code if dl_response else 'N/A'}")
        with open(temp_video_file_path, "wb") as f:
            for chunk in dl_response.iter_content(chunk_size=8192 * 4):
                f.write(chunk)
        logger.info(f"Task {task_id}: Video downloaded to {temp_video_file_path}")
        background_tasks.add_task(process_video_core, task_id, temp_video_file_path, request_data.num_frames, request_data.app_base_url)
        status_url_path = app.url_path_for("get_task_status_endpoint", task_id=task_id)
        full_status_url = request_data.app_base_url.rstrip('/') + str(status_url_path)
        return TaskCreationResponse(
            task_id=task_id,
            status_url=full_status_url,
            message="Video processing task accepted and started in background."
        )
    except HTTPException:
        raise  # Re-raise as-is; the generic handler below must not turn a 400 into a 500
    except requests.exceptions.RequestException as e:
        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)):
            shutil.rmtree(os.path.dirname(temp_video_file_path))
        tasks_db[task_id].update({"status": "failed", "message": f"Video download error: {e}"})
        raise HTTPException(status_code=400, detail=f"Error downloading video: {e}")
    except Exception as e:
        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)):
            shutil.rmtree(os.path.dirname(temp_video_file_path))
        logger.error(f"Task {task_id}: Unexpected error during task submission: {e}", exc_info=True)
        tasks_db[task_id].update({"status": "failed", "message": "Internal server error during task submission."})
        raise HTTPException(status_code=500, detail="Internal server error during task submission.")
@app.post("/upload_video_async", response_model=TaskCreationResponse, status_code=202)
async def upload_video_async_endpoint(
    background_tasks: BackgroundTasks,
    app_base_url: str = Form(..., description="Public base URL of this API service."),
    num_frames: int = Form(10, gt=0, le=50, description="Number of frames to extract (1-50)."),
    video_file: UploadFile = File(..., description="Video file to upload and process.")
):
    if not video_file.content_type or not video_file.content_type.startswith("video/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
    task_id = str(uuid.uuid4())
    tasks_db[task_id] = {"status": "pending", "message": "Task received, saving uploaded video."}
    temp_video_file_path = None
    try:
        upload_dir = os.path.join(BASE_FRAMES_DIR, "_video_uploads", task_id)  # Task-specific upload dir
        os.makedirs(upload_dir, exist_ok=True)
        suffix = os.path.splitext(video_file.filename)[1] if video_file.filename and "." in video_file.filename else ".mp4"
        if not suffix.startswith("."):
            suffix = "." + suffix
        temp_video_file_path = os.path.join(upload_dir, f"uploaded_video{suffix}")
        with open(temp_video_file_path, "wb") as buffer:
            shutil.copyfileobj(video_file.file, buffer)
        logger.info(f"Task {task_id}: Video uploaded and saved to {temp_video_file_path}")
        background_tasks.add_task(process_video_core, task_id, temp_video_file_path, num_frames, app_base_url)
        status_url_path = app.url_path_for("get_task_status_endpoint", task_id=task_id)
        full_status_url = app_base_url.rstrip('/') + str(status_url_path)
        return TaskCreationResponse(
            task_id=task_id,
            status_url=full_status_url,
            message="Video upload accepted and processing started in background."
        )
    except Exception as e:
        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)):
            shutil.rmtree(os.path.dirname(temp_video_file_path))
        logger.error(f"Task {task_id}: Error handling video upload: {e}", exc_info=True)
        tasks_db[task_id].update({"status": "failed", "message": "Internal server error during video upload."})
        raise HTTPException(status_code=500, detail=f"Error processing uploaded file: {e}")
    finally:
        if video_file:
            await video_file.close()
@app.get("/tasks/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status_endpoint(task_id: str):
    task = tasks_db.get(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found.")
    return TaskStatusResponse(task_id=task_id, **task)
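
# Client-side polling sketch (illustrative; the base URL and task id below are
# placeholders, and the client needs the `requests` package):
#
#   import time, requests
#   status_url = "https://your-space.hf.space/tasks/<task_id>/status"
#   while True:
#       body = requests.get(status_url, timeout=30).json()
#       if body["status"] in ("completed", "failed"):
#           break
#       time.sleep(5)
#   print(body.get("result"))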
# --- Homepage Endpoint ---
@app.get("/", response_class=HTMLResponse)
async def read_root(fastapi_request: Request):  # Named 'fastapi_request' to avoid clashing with FastAPI's Request import
    # Determining app_base_url automatically is tricky behind proxies. On Hugging Face,
    # the Space hostname (or headers such as X-Forwarded-Host) can help; otherwise the
    # user must supply it. The value computed below is only a guess used in the example
    # commands; ideally the Space provides it as an environment variable.
    hf_space_name = os.getenv("SPACE_ID", "your-username-your-space-name")
    if hf_space_name == "your-username-your-space-name" and fastapi_request.headers.get("host"):
        # If the Host header looks like user-space.hf.space, use it directly
        host = fastapi_request.headers.get("host")
        if host and ".hf.space" in host:
            hf_space_name = host
    # If running locally, use localhost
    scheme = fastapi_request.url.scheme
    port = fastapi_request.url.port
    host = fastapi_request.url.hostname
    if host == "localhost" or host == "127.0.0.1":
        example_app_base_url = f"{scheme}://{host}:{port}" if port else f"{scheme}://{host}"
    else:  # Assume a deployment, e.g. on Hugging Face
        example_app_base_url = f"https://{hf_space_name}.hf.space" if ".hf.space" not in hf_space_name else f"https://{hf_space_name}"
    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>NSFW Video Detector API</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; background-color: #f4f4f4; color: #333; }}
            .container {{ background-color: #fff; padding: 20px; border-radius: 8px; box-shadow: 0 0 10px rgba(0,0,0,0.1); }}
            h1, h2, h3 {{ color: #333; }}
            h1 {{ text-align: center; border-bottom: 2px solid #eee; padding-bottom: 10px; }}
            h2 {{ border-bottom: 1px solid #eee; padding-bottom: 5px; margin-top: 30px; }}
            code {{ background-color: #eef; padding: 2px 6px; border-radius: 4px; font-family: "Courier New", Courier, monospace; }}
            pre {{ background-color: #eef; padding: 15px; border-radius: 4px; overflow-x: auto; border: 1px solid #ddd; }}
            .endpoint {{ margin-bottom: 20px; }}
            .param {{ font-weight: bold; }}
            .note {{ background-color: #fff9c4; border-left: 4px solid #fdd835; padding: 10px; margin: 15px 0; border-radius: 4px; }}
            .tip {{ background-color: #e8f5e9; border-left: 4px solid #4caf50; padding: 10px; margin: 15px 0; border-radius: 4px; }}
            table {{ width: 100%; border-collapse: collapse; margin-top: 10px; }}
            th, td {{ text-align: left; padding: 8px; border-bottom: 1px solid #ddd; }}
            th {{ background-color: #f0f0f0; }}
            a {{ color: #007bff; text-decoration: none; }}
            a:hover {{ text-decoration: underline; }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>NSFW Video Detector API</h1>
            <p>This API allows you to process videos to detect Not Safe For Work (NSFW) content. It works asynchronously: you submit a video (via URL or direct upload), receive a task ID, and then poll a status endpoint to get the results.</p>
            <div class="note">
                <p><span class="param">Important:</span> The <code>app_base_url</code> parameter is crucial. It must be the public base URL where this API service is accessible. For example, if your Hugging Face Space URL is <code>https://your-username-your-space-name.hf.space</code>, then that is your <code>app_base_url</code>.</p>
                <p>Current detected example base URL for instructions: <code>{example_app_base_url}</code> (this is a guess; please verify your actual public URL).</p>
            </div>
            <h2>Endpoints</h2>
            <div class="endpoint">
                <h3>1. Process Video from URL (Asynchronous)</h3>
                <p><code>POST /process_video_async</code></p>
                <p>Submits a video from a public URL for NSFW analysis.</p>
                <h4>Request Body (JSON):</h4>
                <table>
                    <tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr>
                    <tr><td><span class="param">video_url</span></td><td>string</td><td><em>Required</em></td><td>Publicly accessible URL of the video.</td></tr>
                    <tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract and analyze (1-50).</td></tr>
                    <tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr>
                </table>
                <h4>Example using <code>curl</code>:</h4>
                <pre><code>curl -X POST "{example_app_base_url}/process_video_async" \\
    -H "Content-Type: application/json" \\
    -d '{{
        "video_url": "YOUR_PUBLIC_VIDEO_URL_HERE.mp4",
        "num_frames": 5,
        "app_base_url": "{example_app_base_url}"
    }}'</code></pre>
            </div>
            <div class="endpoint">
                <h3>2. Upload Video File (Asynchronous)</h3>
                <p><code>POST /upload_video_async</code></p>
                <p>Uploads a video file directly for NSFW analysis.</p>
                <h4>Request Body (Multipart Form-Data):</h4>
                <table>
                    <tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr>
                    <tr><td><span class="param">video_file</span></td><td>file</td><td><em>Required</em></td><td>The video file to upload.</td></tr>
                    <tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract (1-50).</td></tr>
                    <tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr>
                </table>
                <h4>Example using <code>curl</code>:</h4>
                <pre><code>curl -X POST "{example_app_base_url}/upload_video_async" \\
    -F "video_file=@/path/to/your/video.mp4" \\
    -F "num_frames=5" \\
    -F "app_base_url={example_app_base_url}"</code></pre>
            </div>
            <div class="tip">
                <h4>Response for Task Creation (for both URL and Upload):</h4>
                <p>If successful (HTTP 202 Accepted), the API will respond with:</p>
                <pre><code>{{
    "task_id": "some-unique-task-id",
    "status_url": "{example_app_base_url}/tasks/some-unique-task-id/status",
    "message": "Video processing task accepted and started in background."
}}</code></pre>
            </div>
            <div class="endpoint">
                <h3>3. Get Task Status and Result</h3>
                <p><code>GET /tasks/&lt;task_id&gt;/status</code></p>
                <p>Poll this endpoint to check the status of a processing task and retrieve the result once completed.</p>
                <h4>Example using <code>curl</code>:</h4>
                <pre><code>curl -X GET "{example_app_base_url}/tasks/some-unique-task-id/status"</code></pre>
                <h4>Possible Statuses:</h4>
                <ul>
                    <li><code>pending</code>: Task is queued.</li>
                    <li><code>processing</code>: Task is actively being processed (downloading, extracting frames, analyzing).</li>
                    <li><code>completed</code>: Task finished successfully. Results are available in the <code>result</code> field.</li>
                    <li><code>failed</code>: Task failed. Check the <code>message</code> field for details.</li>
                </ul>
                <h4>Example Response (Status: <code>completed</code>):</h4>
                <pre><code>{{
    "task_id": "some-unique-task-id",
    "status": "completed",
    "message": "Processing complete.",
    "result": {{
        "nsfw_count": 1,
        "total_frames_analyzed": 5,
        "frames": [
            {{
                "frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid1.jpg",
                "nsfw_detected": "false"
            }},
            {{
                "frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid2.jpg",
                "nsfw_detected": "true"
            }}
            // ... more frames
        ]
    }}
}}</code></pre>
                <h4>Example Response (Status: <code>processing</code>):</h4>
                <pre><code>{{
    "task_id": "some-unique-task-id",
    "status": "processing",
    "message": "Analyzing 5 frames...",
    "result": null
}}</code></pre>
            </div>
            <p style="text-align:center; margin-top:30px; font-size:0.9em; color:#777;">API Version: {app.version}</p>
        </div>
    </body>
    </html>
    """
    # If using Jinja2 templates instead of the inline string:
    # if templates:
    #     return templates.TemplateResponse("index.html", {"request": fastapi_request, "app_version": app.version, "example_app_base_url": example_app_base_url})
    # else:
    #     return HTMLResponse(content=html_content, status_code=200)
    return HTMLResponse(content=html_content, status_code=200)
# Example of how to run for local development:
# 1. Ensure you have a 'templates/index.html' file, or the fallback HTML string will be used.
# 2. Run: uvicorn main:app --reload --host 0.0.0.0 --port 8000
#    (assuming your file is named main.py)
# Requirements: fastapi uvicorn[standard] opencv-python requests pydantic python-multipart (for Form/File uploads)
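
# Convenience entry point for local runs (a minimal sketch; assumes the uvicorn
# package is installed, per the requirements above):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)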