from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from typing import List, Dict, Optional

import cv2
import uuid
import os
import requests
import random
import string
import json
import shutil
import asyncio
import logging

app = FastAPI(title="Advanced NSFW Video Detector API", version="1.1.0")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    templates_path = os.path.join(os.path.dirname(__file__), "templates")
    if not os.path.exists(templates_path):
        os.makedirs(templates_path)
    templates = Jinja2Templates(directory=templates_path)

    # Create a placeholder index.html so template lookups cannot fail at startup.
    dummy_html_path = os.path.join(templates_path, "index.html")
    if not os.path.exists(dummy_html_path):
        with open(dummy_html_path, "w") as f:
            f.write("<h1>Dummy Index Page - Replace with actual instructions</h1>")
except Exception as e:
    logger.warning(f"Jinja2Templates initialization failed: {e}. Will use basic HTML string for homepage.")
    templates = None

DEFAULT_REQUEST_TIMEOUT = 20
MAX_RETRY_ATTEMPTS = 3
RETRY_BACKOFF_FACTOR = 2
# Shared, module-level limiter so bursts of concurrent checker calls cannot
# exhaust the thread-pool executor used for blocking HTTP requests.
HTTP_CONCURRENCY_SEMAPHORE = asyncio.Semaphore(10)

NSFW_CHECKER_CONFIG = {
    "checker1_yoinked": {
        "queue_join_url": "https://yoinked-da-nsfw-checker.hf.space/queue/join",
        "queue_data_url_template": "https://yoinked-da-nsfw-checker.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}, "chen-convnext", 0.5, True, True],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 12
        }
    },
    "checker2_jamescookjr90": {
        "queue_join_url": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    },
    "checker3_zanderlewis": {
        "predict_url": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict",
        "event_url_template": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict/{event_id}",
        "payload_template": lambda img_url: {'data': [{'path': img_url}]}
    },
    "checker4_error466": {
        "base_url": "https://error466-falconsai-nsfw-image-detection.hf.space",
        "replica_code_needed": True,
        "queue_join_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/join",
        "queue_data_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 58
        }
    },
    "checker5_phelpsgg": {
        "queue_join_url": "https://imseldrith-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://imseldrith-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    }
}

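# For illustration, the payload that "checker1_yoinked" would send for a
# hypothetical image URL and session hash (values below are made up, not from
# a real session):
#
#   {
#       'data': [{'path': 'https://example.com/frame.jpg'}, "chen-convnext", 0.5, True, True],
#       'session_hash': 'a1b2c3d4e5', 'fn_index': 0, 'trigger_id': 12
#   }
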
# In-memory task store; entries are lost on restart and are never evicted.
tasks_db: Dict[str, Dict] = {}

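# Shape of a typical entry as a task progresses (illustrative only):
#
#   tasks_db["<task_id>"] = {"status": "pending",    "message": "Task received, ..."}
#   tasks_db["<task_id>"] = {"status": "processing", "message": "Analyzing 5 frames..."}
#   tasks_db["<task_id>"] = {"status": "completed",  "message": "Processing complete.",
#                            "result": {"nsfw_count": 1, "total_frames_analyzed": 5, "frames": [...]}}
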
async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
    """Makes an HTTP request with retries, exponential backoff, and jitter."""
    headers = kwargs.pop("headers", {})
    headers.setdefault("User-Agent", "NSFWDetectorClient/1.1")

    for attempt in range(MAX_RETRY_ATTEMPTS):
        try:
            # requests is blocking, so run it in the default thread-pool executor;
            # the shared semaphore caps how many of these calls run at once.
            async with HTTP_CONCURRENCY_SEMAPHORE:
                loop = asyncio.get_running_loop()
                response = await loop.run_in_executor(
                    None,
                    lambda: requests.request(method, url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT, **kwargs)
                )
            response.raise_for_status()
            return response
        except requests.exceptions.Timeout:
            logger.warning(f"Request timeout for {url} on attempt {attempt + 1}")
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code in [429, 502, 503, 504]:
                logger.warning(f"HTTP error {e.response.status_code} for {url} on attempt {attempt + 1}")
            else:
                logger.error(f"Non-retriable HTTP error for {url}: {e}")
                return e.response if e.response is not None else None
        except requests.exceptions.RequestException as e:
            logger.error(f"Request exception for {url} on attempt {attempt + 1}: {e}")

        if attempt < MAX_RETRY_ATTEMPTS - 1:
            delay = (RETRY_BACKOFF_FACTOR ** attempt) + random.uniform(0, 0.5)
            logger.info(f"Retrying {url} in {delay:.2f} seconds...")
            await asyncio.sleep(delay)

    logger.error(f"All {MAX_RETRY_ATTEMPTS} retry attempts failed for {url}.")
    return None

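# Usage sketch (illustrative; the example URL is hypothetical):
#
#   resp = await http_request_with_retry("GET", "https://example.com/image.jpg", stream=True)
#   if resp is not None and resp.status_code == 200:
#       ...  # consume resp.text or resp.iter_content(...)
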
def get_replica_code_sync(url: str) -> Optional[str]:
    """Scrapes the replica identifier out of a Space's HTML page. Brittle by
    nature: it depends on the page embedding a '--replicas/<code>' URL."""
    try:
        r = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers={"User-Agent": "NSFWDetectorClient/1.1"})
        r.raise_for_status()

        parts = r.text.split('replicas/')
        if len(parts) > 1:
            return parts[1].split('"};')[0]
        logger.warning(f"Could not find 'replicas/' in content from {url}")
        return None
    except (requests.exceptions.RequestException, IndexError, KeyError) as e:
        logger.error(f"Error getting replica code for {url}: {e}")
        return None


async def get_replica_code(url: str) -> Optional[str]:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, get_replica_code_sync, url)

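# The scrape above assumes the Space's HTML contains a fragment roughly like
# (the replica code 'abc12' is made up):
#
#   ... "https://....hf.space/--replicas/abc12"}; ...
#
# so splitting on 'replicas/' and then on '"};' isolates 'abc12'.
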
async def parse_hf_queue_response(response_content: str) -> Optional[str]:
    """Extracts the classification label from a Gradio queue SSE stream, walking
    the 'data:' messages from newest to oldest looking for 'process_completed'."""
    try:
        messages = response_content.strip().split('\n')
        for msg_str in reversed(messages):
            if msg_str.startswith("data:"):
                try:
                    data_json_str = msg_str[len("data:"):].strip()
                    if not data_json_str:
                        continue

                    parsed_json = json.loads(data_json_str)
                    if parsed_json.get("msg") == "process_completed":
                        output_data = parsed_json.get("output", {}).get("data")
                        if output_data and isinstance(output_data, list) and len(output_data) > 0:
                            first_item = output_data[0]
                            if isinstance(first_item, dict):
                                return first_item.get('label')
                            if isinstance(first_item, str):
                                return first_item
                        logger.warning(f"Unexpected 'process_completed' data structure: {output_data}")
                        return None
                except json.JSONDecodeError:
                    logger.debug(f"Failed to decode JSON from part of HF stream: {data_json_str[:100]}")
                    continue
        return None
    except Exception as e:
        logger.error(f"Error parsing HF queue response: {e}, content: {response_content[:200]}")
        return None

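# Example of the SSE line this parser is looking for (structure as read by the
# code above; the label value is made up):
#
#   data: {"msg": "process_completed", "output": {"data": [{"label": "nsfw"}]}}
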
async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
    """Runs one configured checker against an image URL and returns its raw label
    (e.g. 'nsfw' / 'sfw'), or None if the checker failed or was inconclusive."""
    config = NSFW_CHECKER_CONFIG.get(checker_name)
    if not config:
        logger.error(f"No configuration found for checker: {checker_name}")
        return None

    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(10))

    try:
        if "predict_url" in config:
            # Two-step /call/predict API: POST returns an event_id, then poll the event URL.
            payload = config["payload_template"](img_url)
            response_predict = await http_request_with_retry("POST", config["predict_url"], json=payload)
            if not response_predict or response_predict.status_code != 200:
                logger.error(f"{checker_name} predict call failed or returned non-200. Status: {response_predict.status_code if response_predict else 'N/A'}")
                return None

            json_data = response_predict.json()
            event_id = json_data.get('event_id')
            if not event_id:
                logger.error(f"{checker_name} did not return event_id.")
                return None

            event_url = config["event_url_template"].format(event_id=event_id)
            for _ in range(10):
                await asyncio.sleep(random.uniform(1.5, 2.5))
                response_event = await http_request_with_retry("GET", event_url, stream=True)
                if response_event and response_event.status_code == 200:
                    event_stream_content = response_event.text
                    if 'data:' in event_stream_content:
                        final_data_str = event_stream_content.strip().split('data:')[-1].strip()
                        if final_data_str:
                            try:
                                # The event payload is JSON, so json.loads handles
                                # true/false/null, which ast.literal_eval would not.
                                parsed_list = json.loads(final_data_str)
                                if isinstance(parsed_list, list) and parsed_list:
                                    return parsed_list[0].get('label')
                                logger.warning(f"{checker_name} parsed empty or invalid list from event stream: {final_data_str[:100]}")
                            except (ValueError, IndexError, TypeError, AttributeError) as e:
                                logger.warning(f"{checker_name} error parsing event stream: {e}, stream: {final_data_str[:100]}")
                elif response_event:
                    logger.warning(f"{checker_name} polling event_url returned status {response_event.status_code}")
                else:
                    logger.warning(f"{checker_name} polling event_url got no response.")

        else:
            # Gradio queue API: POST to queue/join, then poll queue/data for SSE updates.
            join_url = config["queue_join_url"]
            data_url_template = config["queue_data_url_template"]

            if config.get("replica_code_needed"):
                replica_base_url = config.get("base_url")
                if not replica_base_url:
                    logger.error(f"{checker_name} needs replica_code but base_url is missing.")
                    return None
                code = await get_replica_code(replica_base_url)
                if not code:
                    logger.error(f"Failed to get replica code for {checker_name}")
                    return None
                join_url = config["queue_join_url_template"].format(code=code)
                data_url = data_url_template.format(code=code, session_hash=session_hash)
            else:
                data_url = data_url_template.format(session_hash=session_hash)

            payload = config["payload_template"](img_url, session_hash)

            response_join = await http_request_with_retry("POST", join_url, json=payload)
            if not response_join or response_join.status_code != 200:
                logger.error(f"{checker_name} queue/join call failed. Status: {response_join.status_code if response_join else 'N/A'}")
                return None

            for _ in range(15):
                await asyncio.sleep(random.uniform(1.5, 2.5))
                response_data = await http_request_with_retry("GET", data_url, stream=True)
                if response_data and response_data.status_code == 200:
                    buffer = ""
                    for content_chunk in response_data.iter_content(chunk_size=1024, decode_unicode=True):
                        if content_chunk:
                            buffer += content_chunk
                            # SSE messages end with a blank line; parse once a
                            # full message has been buffered.
                            if buffer.endswith("\n\n"):
                                label = await parse_hf_queue_response(buffer)
                                if label:
                                    return label
                                buffer = ""
                elif response_data:
                    logger.warning(f"{checker_name} polling queue/data returned status {response_data.status_code}")
                else:
                    logger.warning(f"{checker_name} polling queue/data got no response.")

        logger.warning(f"{checker_name} failed to get a conclusive result for {img_url}")
        return None

    except Exception as e:
        logger.error(f"Exception in {checker_name} for {img_url}: {e}", exc_info=True)
        return None

async def check_nsfw_final_concurrent(img_url: str) -> Optional[bool]:
    """Fans one image out to all checkers concurrently.
    Returns True on the first NSFW verdict, False if no checker said NSFW but at
    least one said SFW, and None if every checker failed or was inconclusive."""
    logger.info(f"Starting NSFW check for: {img_url}")
    checker_names = [
        "checker2_jamescookjr90", "checker3_zanderlewis", "checker5_phelpsgg",
        "checker4_error466", "checker1_yoinked"
    ]

    named_tasks = {
        name: asyncio.create_task(check_nsfw_single_generic(name, img_url))
        for name in checker_names
    }

    sfw_found_by_any_checker = False

    for task_name in named_tasks:
        try:
            label = await named_tasks[task_name]
            logger.info(f"Checker '{task_name}' result for {img_url}: {label}")
            if label:
                label_lower = label.lower()
                if 'nsfw' in label_lower:
                    logger.info(f"NSFW detected by '{task_name}' for {img_url}. Final: True.")
                    # One positive verdict is decisive; cancel the checkers still running.
                    for pending_task in named_tasks.values():
                        if not pending_task.done():
                            pending_task.cancel()
                    return True
                if 'sfw' in label_lower or 'safe' in label_lower:
                    sfw_found_by_any_checker = True
        except asyncio.CancelledError:
            logger.info(f"Checker '{task_name}' was cancelled for {img_url}.")
        except Exception as e:
            logger.error(f"Error processing result from checker '{task_name}' for {img_url}: {e}")

    if sfw_found_by_any_checker:
        logger.info(f"SFW confirmed for {img_url} (no NSFW detected, at least one SFW). Final: False.")
        return False

    logger.warning(f"All NSFW checkers inconclusive or failed for {img_url}. Final: None.")
    return None

BASE_FRAMES_DIR = "/tmp/video_frames_service_advanced_v2"
os.makedirs(BASE_FRAMES_DIR, exist_ok=True)
# Extracted frames are served over HTTP so the remote checker Spaces can fetch them.
app.mount("/static_frames", StaticFiles(directory=BASE_FRAMES_DIR), name="static_frames")

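# Path-to-URL mapping (illustrative; task id and filename are made up):
#
#   disk: /tmp/video_frames_service_advanced_v2/<task_id>/frame_ab12.jpg
#   URL:  {app_base_url}/static_frames/<task_id>/frame_ab12.jpg
#
# This is why callers must supply a publicly reachable app_base_url: the
# external checkers download frames from these URLs.
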
def extract_frames_sync(video_path, num_frames_to_extract, request_specific_frames_dir):
    """Samples frames evenly across the video and writes them as JPEGs.
    Returns the list of written frame paths (possibly empty on failure)."""
    vidcap = cv2.VideoCapture(video_path)
    if not vidcap.isOpened():
        logger.error(f"Cannot open video file: {video_path}")
        return []
    total_frames_in_video = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    extracted_frame_paths = []

    if total_frames_in_video == 0:
        logger.warning(f"Video {video_path} has no frames.")
        vidcap.release()
        return []

    # Never try to extract more frames than the video actually has.
    actual_frames_to_extract = min(num_frames_to_extract, total_frames_in_video)
    if actual_frames_to_extract <= 0:
        vidcap.release()
        return []

    # Pick evenly spaced frame indices across the whole video.
    for i in range(actual_frames_to_extract):
        frame_number = int(i * total_frames_in_video / actual_frames_to_extract)
        frame_number = min(frame_number, total_frames_in_video - 1)

        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        success, image = vidcap.read()
        if success:
            frame_filename = os.path.join(request_specific_frames_dir, f"frame_{uuid.uuid4().hex}.jpg")
            if cv2.imwrite(frame_filename, image):
                extracted_frame_paths.append(frame_filename)
            else:
                logger.error(f"Failed to write frame: {frame_filename}")
        else:
            logger.warning(f"Failed to read frame at position {frame_number} from {video_path}. Total frames: {total_frames_in_video}")

    vidcap.release()
    return extracted_frame_paths

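# Worked example of the sampling arithmetic: for a 100-frame video with
# num_frames_to_extract=5, the loop picks indices
#   int(0*100/5)=0, int(1*100/5)=20, int(2*100/5)=40, 60, 80
# i.e. one frame from the start of each fifth of the video.
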
async def process_video_core(task_id: str, video_path_on_disk: str, num_frames_to_analyze: int, app_base_url: str):
    """Background task: extract frames, fan them out to the checkers, and record
    the aggregate result in tasks_db. Cleans up the source video when done."""
    tasks_db[task_id].update({"status": "processing", "message": "Extracting frames..."})

    request_frames_subdir = os.path.join(BASE_FRAMES_DIR, task_id)
    os.makedirs(request_frames_subdir, exist_ok=True)

    extracted_frames_disk_paths = []
    try:
        loop = asyncio.get_running_loop()
        extracted_frames_disk_paths = await loop.run_in_executor(
            None, extract_frames_sync, video_path_on_disk, num_frames_to_analyze, request_frames_subdir
        )

        if not extracted_frames_disk_paths:
            tasks_db[task_id].update({"status": "failed", "message": "No frames could be extracted."})
            logger.error(f"Task {task_id}: No frames extracted from {video_path_on_disk}")
            if os.path.exists(video_path_on_disk):
                os.remove(video_path_on_disk)
            return

        tasks_db[task_id].update({
            "status": "processing",
            "message": f"Analyzing {len(extracted_frames_disk_paths)} frames..."
        })

        nsfw_count = 0
        frame_results_list = []
        base_url_for_static_frames = f"{app_base_url.rstrip('/')}/static_frames/{task_id}"

        # Analyze all frames concurrently; each coroutine fans out to every checker.
        analysis_coroutines = []
        for frame_disk_path in extracted_frames_disk_paths:
            frame_filename_only = os.path.basename(frame_disk_path)
            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
            analysis_coroutines.append(check_nsfw_final_concurrent(img_http_url))

        nsfw_detection_results = await asyncio.gather(*analysis_coroutines, return_exceptions=True)

        for i, detection_result in enumerate(nsfw_detection_results):
            frame_disk_path = extracted_frames_disk_paths[i]
            frame_filename_only = os.path.basename(frame_disk_path)
            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
            is_nsfw_str = "unknown"

            if isinstance(detection_result, Exception):
                logger.error(f"Task {task_id}: Error analyzing frame {img_http_url}: {detection_result}")
                is_nsfw_str = "error"
            elif detection_result is True:
                nsfw_count += 1
                is_nsfw_str = "true"
            elif detection_result is False:
                is_nsfw_str = "false"

            frame_results_list.append({"frame_url": img_http_url, "nsfw_detected": is_nsfw_str})

        result_summary = {
            "nsfw_count": nsfw_count,
            "total_frames_analyzed": len(extracted_frames_disk_paths),
            "frames": frame_results_list
        }
        tasks_db[task_id].update({"status": "completed", "result": result_summary, "message": "Processing complete."})
        logger.info(f"Task {task_id}: Processing complete. Result: {result_summary}")

    except Exception as e:
        logger.error(f"Task {task_id}: Unhandled exception in process_video_core: {e}", exc_info=True)
        tasks_db[task_id].update({"status": "failed", "message": f"An internal error occurred: {str(e)}"})
    finally:
        if os.path.exists(video_path_on_disk):
            try:
                os.remove(video_path_on_disk)
                logger.info(f"Task {task_id}: Cleaned up video file: {video_path_on_disk}")
            except OSError as e_remove:
                logger.error(f"Task {task_id}: Error cleaning up video file {video_path_on_disk}: {e_remove}")

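# Result recorded for a completed task (illustrative values):
#
#   {"nsfw_count": 1, "total_frames_analyzed": 5,
#    "frames": [{"frame_url": ".../static_frames/<task_id>/frame_....jpg",
#                "nsfw_detected": "true"}, ...]}
#
# "nsfw_detected" is a string rather than a bool because it has four states:
# "true", "false", "unknown" (all checkers inconclusive), and "error".
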
class VideoProcessRequest(BaseModel):
    video_url: Optional[str] = Field(None, description="Publicly accessible URL of the video to process.")
    num_frames: int = Field(10, gt=0, le=50, description="Number of frames to extract (1-50). Max 50 for performance.")
    app_base_url: str = Field(..., description="Public base URL of this API service (e.g., https://your-username-your-space-name.hf.space).")


class TaskCreationResponse(BaseModel):
    task_id: str
    status_url: str
    message: str


class FrameResult(BaseModel):
    frame_url: str
    nsfw_detected: str


class VideoProcessResult(BaseModel):
    nsfw_count: int
    total_frames_analyzed: int
    frames: List[FrameResult]


class TaskStatusResponse(BaseModel):
    task_id: str
    status: str
    message: Optional[str] = None
    result: Optional[VideoProcessResult] = None

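# Example request body for /process_video_async (the video URL is a placeholder):
#
#   {
#       "video_url": "https://example.com/some_video.mp4",
#       "num_frames": 5,
#       "app_base_url": "https://your-username-your-space-name.hf.space"
#   }
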
@app.post("/process_video_async", response_model=TaskCreationResponse, status_code=202) |
|
|
async def process_video_from_url_async_endpoint( |
|
|
request_data: VideoProcessRequest, |
|
|
background_tasks: BackgroundTasks |
|
|
): |
|
|
if not request_data.video_url: |
|
|
raise HTTPException(status_code=400, detail="video_url must be provided.") |
|
|
|
|
|
task_id = str(uuid.uuid4()) |
|
|
tasks_db[task_id] = {"status": "pending", "message": "Task received, preparing for download."} |
|
|
|
|
|
temp_video_file_path = None |
|
|
try: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
task_download_dir = os.path.join(BASE_FRAMES_DIR, "_video_downloads", task_id) |
|
|
os.makedirs(task_download_dir, exist_ok=True) |
|
|
|
|
|
video_suffix = os.path.splitext(request_data.video_url.split("?")[0])[-1] or ".mp4" |
|
|
if not video_suffix.startswith("."): video_suffix = "." + video_suffix |
|
|
|
|
|
|
|
|
temp_video_file_path = os.path.join(task_download_dir, f"downloaded_video{video_suffix}") |
|
|
|
|
|
logger.info(f"Task {task_id}: Attempting to download video from {request_data.video_url} to {temp_video_file_path}") |
|
|
|
|
|
dl_response = await http_request_with_retry("GET", request_data.video_url, stream=True) |
|
|
if not dl_response or dl_response.status_code != 200: |
|
|
if os.path.exists(task_download_dir): shutil.rmtree(task_download_dir) |
|
|
tasks_db[task_id].update({"status": "failed", "message": f"Failed to download video. Status: {dl_response.status_code if dl_response else 'N/A'}"}) |
|
|
raise HTTPException(status_code=400, detail=f"Error downloading video: Status {dl_response.status_code if dl_response else 'N/A'}") |
|
|
|
|
|
with open(temp_video_file_path, "wb") as f: |
|
|
for chunk in dl_response.iter_content(chunk_size=8192*4): |
|
|
f.write(chunk) |
|
|
logger.info(f"Task {task_id}: Video downloaded to {temp_video_file_path}") |
|
|
|
|
|
background_tasks.add_task(process_video_core, task_id, temp_video_file_path, request_data.num_frames, request_data.app_base_url) |
|
|
|
|
|
status_url_path = app.url_path_for("get_task_status_endpoint", task_id=task_id) |
|
|
full_status_url = str(request_data.app_base_url.rstrip('/') + status_url_path) |
|
|
|
|
|
return TaskCreationResponse( |
|
|
task_id=task_id, |
|
|
status_url=full_status_url, |
|
|
message="Video processing task accepted and started in background." |
|
|
) |
|
|
|
|
|
except requests.exceptions.RequestException as e: |
|
|
if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path)) |
|
|
tasks_db[task_id].update({"status": "failed", "message": f"Video download error: {e}"}) |
|
|
raise HTTPException(status_code=400, detail=f"Error downloading video: {e}") |
|
|
except Exception as e: |
|
|
if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path)) |
|
|
logger.error(f"Task {task_id}: Unexpected error during task submission: {e}", exc_info=True) |
|
|
tasks_db[task_id].update({"status": "failed", "message": "Internal server error during task submission."}) |
|
|
raise HTTPException(status_code=500, detail="Internal server error during task submission.") |
|
|
|
|
|
|
|
|
@app.post("/upload_video_async", response_model=TaskCreationResponse, status_code=202) |
|
|
async def upload_video_async_endpoint( |
|
|
background_tasks: BackgroundTasks, |
|
|
app_base_url: str = Form(..., description="Public base URL of this API service."), |
|
|
num_frames: int = Form(10, gt=0, le=50, description="Number of frames to extract (1-50)."), |
|
|
video_file: UploadFile = File(..., description="Video file to upload and process.") |
|
|
): |
|
|
if not video_file.content_type or not video_file.content_type.startswith("video/"): |
|
|
raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.") |
|
|
|
|
|
task_id = str(uuid.uuid4()) |
|
|
tasks_db[task_id] = {"status": "pending", "message": "Task received, saving uploaded video."} |
|
|
temp_video_file_path = None |
|
|
try: |
|
|
upload_dir = os.path.join(BASE_FRAMES_DIR, "_video_uploads", task_id) |
|
|
os.makedirs(upload_dir, exist_ok=True) |
|
|
|
|
|
suffix = os.path.splitext(video_file.filename)[1] if video_file.filename and "." in video_file.filename else ".mp4" |
|
|
if not suffix.startswith("."): suffix = "." + suffix |
|
|
|
|
|
temp_video_file_path = os.path.join(upload_dir, f"uploaded_video{suffix}") |
|
|
|
|
|
with open(temp_video_file_path, "wb") as buffer: |
|
|
shutil.copyfileobj(video_file.file, buffer) |
|
|
logger.info(f"Task {task_id}: Video uploaded and saved to {temp_video_file_path}") |
|
|
|
|
|
background_tasks.add_task(process_video_core, task_id, temp_video_file_path, num_frames, app_base_url) |
|
|
|
|
|
status_url_path = app.url_path_for("get_task_status_endpoint", task_id=task_id) |
|
|
full_status_url = str(app_base_url.rstrip('/') + status_url_path) |
|
|
return TaskCreationResponse( |
|
|
task_id=task_id, |
|
|
status_url=full_status_url, |
|
|
message="Video upload accepted and processing started in background." |
|
|
) |
|
|
except Exception as e: |
|
|
if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path)) |
|
|
logger.error(f"Task {task_id}: Error handling video upload: {e}", exc_info=True) |
|
|
tasks_db[task_id].update({"status": "failed", "message": "Internal server error during video upload."}) |
|
|
raise HTTPException(status_code=500, detail=f"Error processing uploaded file: {e}") |
|
|
finally: |
|
|
if video_file: |
|
|
await video_file.close() |
|
|
|
|
|
|
|
|
@app.get("/tasks/{task_id}/status", response_model=TaskStatusResponse) |
|
|
async def get_task_status_endpoint(task_id: str): |
|
|
task = tasks_db.get(task_id) |
|
|
if not task: |
|
|
raise HTTPException(status_code=404, detail="Task not found.") |
|
|
return TaskStatusResponse(task_id=task_id, **task) |
|
|
|
|
|
|
|
|
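# Client-side polling sketch (illustrative; the base URL and task id are
# placeholders, the endpoint shape is as defined above):
#
#   import time, requests
#   status_url = "https://your-space.hf.space/tasks/<task_id>/status"
#   while True:
#       data = requests.get(status_url, timeout=20).json()
#       if data["status"] in ("completed", "failed"):
#           break
#       time.sleep(5)
#   print(data.get("result"))
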
@app.get("/", response_class=HTMLResponse) |
|
|
async def read_root(fastapi_request: Request): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hf_space_name = os.getenv("SPACE_ID", "your-username-your-space-name") |
|
|
if hf_space_name == "your-username-your-space-name" and fastapi_request.headers.get("host"): |
|
|
|
|
|
host = fastapi_request.headers.get("host") |
|
|
if host and ".hf.space" in host: |
|
|
hf_space_name = host |
|
|
|
|
|
|
|
|
scheme = fastapi_request.url.scheme |
|
|
port = fastapi_request.url.port |
|
|
host = fastapi_request.url.hostname |
|
|
|
|
|
if host == "localhost" or host == "127.0.0.1": |
|
|
example_app_base_url = f"{scheme}://{host}:{port}" if port else f"{scheme}://{host}" |
|
|
else: |
|
|
example_app_base_url = f"https://{hf_space_name}.hf.space" if ".hf.space" not in hf_space_name else f"https://{hf_space_name}" |
|
|
|
|
|
|
|
|
html_content = f""" |
|
|
<!DOCTYPE html> |
|
|
<html lang="en"> |
|
|
<head> |
|
|
<meta charset="UTF-8"> |
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"> |
|
|
<title>NSFW Video Detector API</title> |
|
|
<style> |
|
|
body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; background-color: #f4f4f4; color: #333; }} |
|
|
.container {{ background-color: #fff; padding: 20px; border-radius: 8px; box-shadow: 0 0 10px rgba(0,0,0,0.1); }} |
|
|
h1, h2, h3 {{ color: #333; }} |
|
|
h1 {{ text-align: center; border-bottom: 2px solid #eee; padding-bottom: 10px;}} |
|
|
h2 {{ border-bottom: 1px solid #eee; padding-bottom: 5px; margin-top: 30px;}} |
|
|
code {{ background-color: #eef; padding: 2px 6px; border-radius: 4px; font-family: "Courier New", Courier, monospace;}} |
|
|
pre {{ background-color: #eef; padding: 15px; border-radius: 4px; overflow-x: auto; border: 1px solid #ddd; }} |
|
|
.endpoint {{ margin-bottom: 20px; }} |
|
|
.param {{ font-weight: bold; }} |
|
|
.note {{ background-color: #fff9c4; border-left: 4px solid #fdd835; padding: 10px; margin: 15px 0; border-radius:4px; }} |
|
|
.tip {{ background-color: #e8f5e9; border-left: 4px solid #4caf50; padding: 10px; margin: 15px 0; border-radius:4px; }} |
|
|
table {{ width: 100%; border-collapse: collapse; margin-top:10px; }} |
|
|
th, td {{ text-align: left; padding: 8px; border-bottom: 1px solid #ddd; }} |
|
|
th {{ background-color: #f0f0f0; }} |
|
|
a {{ color: #007bff; text-decoration: none; }} |
|
|
a:hover {{ text-decoration: underline; }} |
|
|
</style> |
|
|
</head> |
|
|
<body> |
|
|
<div class="container"> |
|
|
<h1>NSFW Video Detector API</h1> |
|
|
<p>This API allows you to process videos to detect Not Suitable For Work (NSFW) content. It works asynchronously: you submit a video (via URL or direct upload), receive a task ID, and then poll a status endpoint to get the results.</p> |
|
|
|
|
|
<div class="note"> |
|
|
<p><span class="param">Important:</span> The <code>app_base_url</code> parameter is crucial. It must be the public base URL where this API service is accessible. For example, if your Hugging Face Space URL is <code>https://your-username-your-space-name.hf.space</code>, then that's your <code>app_base_url</code>.</p> |
|
|
<p>Current detected example base URL for instructions: <code>{example_app_base_url}</code> (This is a guess, please verify your actual public URL).</p> |
|
|
</div> |
|
|
|
|
|
<h2>Endpoints</h2> |
|
|
|
|
|
<div class="endpoint"> |
|
|
<h3>1. Process Video from URL (Asynchronous)</h3> |
|
|
<p><code>POST /process_video_async</code></p> |
|
|
<p>Submits a video from a public URL for NSFW analysis.</p> |
|
|
<h4>Request Body (JSON):</h4> |
|
|
<table> |
|
|
<tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr> |
|
|
<tr><td><span class="param">video_url</span></td><td>string</td><td><em>Required</em></td><td>Publicly accessible URL of the video.</td></tr> |
|
|
<tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract and analyze (1-50).</td></tr> |
|
|
<tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr> |
|
|
</table> |
|
|
<h4>Example using <code>curl</code>:</h4> |
|
|
<pre><code>curl -X POST "{example_app_base_url}/process_video_async" \\ |
|
|
-H "Content-Type: application/json" \\ |
|
|
-d '{{ |
|
|
"video_url": "YOUR_PUBLIC_VIDEO_URL_HERE.mp4", |
|
|
"num_frames": 5, |
|
|
"app_base_url": "{example_app_base_url}" |
|
|
}}'</code></pre> |
|
|
</div> |
|
|
|
|
|
<div class="endpoint"> |
|
|
<h3>2. Upload Video File (Asynchronous)</h3> |
|
|
<p><code>POST /upload_video_async</code></p> |
|
|
<p>Uploads a video file directly for NSFW analysis.</p> |
|
|
<h4>Request Body (Multipart Form-Data):</h4> |
|
|
<table> |
|
|
<tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr> |
|
|
<tr><td><span class="param">video_file</span></td><td>file</td><td><em>Required</em></td><td>The video file to upload.</td></tr> |
|
|
<tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract (1-50).</td></tr> |
|
|
<tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr> |
|
|
</table> |
|
|
<h4>Example using <code>curl</code>:</h4> |
|
|
<pre><code>curl -X POST "{example_app_base_url}/upload_video_async" \\ |
|
|
-F "video_file=@/path/to/your/video.mp4" \\ |
|
|
-F "num_frames=5" \\ |
|
|
-F "app_base_url={example_app_base_url}"</code></pre> |
|
|
</div> |
|
|
|
|
|
<div class="tip"> |
|
|
<h4>Response for Task Creation (for both URL and Upload):</h4> |
|
|
<p>If successful (HTTP 202 Accepted), the API will respond with:</p> |
|
|
<pre><code>{{ |
|
|
"task_id": "some-unique-task-id", |
|
|
"status_url": "{example_app_base_url}/tasks/some-unique-task-id/status", |
|
|
"message": "Video processing task accepted and started in background." |
|
|
}}</code></pre> |
|
|
</div> |
|
|
|
|
|
<div class="endpoint"> |
|
|
<h3>3. Get Task Status and Result</h3> |
|
|
<p><code>GET /tasks/<task_id>/status</code></p> |
|
|
<p>Poll this endpoint to check the status of a processing task and retrieve the result once completed.</p> |
|
|
<h4>Example using <code>curl</code>:</h4> |
|
|
<pre><code>curl -X GET "{example_app_base_url}/tasks/some-unique-task-id/status"</code></pre> |
|
|
<h4>Possible Statuses:</h4> |
|
|
<ul> |
|
|
<li><code>pending</code>: Task is queued.</li> |
|
|
<li><code>processing</code>: Task is actively being processed (downloading, extracting frames, analyzing).</li> |
|
|
<li><code>completed</code>: Task finished successfully. Results are available in the <code>result</code> field.</li> |
|
|
<li><code>failed</code>: Task failed. Check the <code>message</code> field for details.</li> |
|
|
</ul> |
|
|
<h4>Example Response (Status: <code>completed</code>):</h4> |
|
|
<pre><code>{{ |
|
|
"task_id": "some-unique-task-id", |
|
|
"status": "completed", |
|
|
"message": "Processing complete.", |
|
|
"result": {{ |
|
|
"nsfw_count": 1, |
|
|
"total_frames_analyzed": 5, |
|
|
"frames": [ |
|
|
{{ |
|
|
"frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid1.jpg", |
|
|
"nsfw_detected": "false" |
|
|
}}, |
|
|
{{ |
|
|
"frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid2.jpg", |
|
|
"nsfw_detected": "true" |
|
|
}} |
|
|
// ... more frames |
|
|
] |
|
|
}} |
|
|
}}</code></pre> |
|
|
<h4>Example Response (Status: <code>processing</code>):</h4> |
|
|
<pre><code>{{ |
|
|
"task_id": "some-unique-task-id", |
|
|
"status": "processing", |
|
|
"message": "Analyzing 5 frames...", |
|
|
"result": null |
|
|
}}</code></pre> |
|
|
</div> |
|
|
<p style="text-align:center; margin-top:30px; font-size:0.9em; color:#777;">API Version: {app.version}</p> |
|
|
</div> |
|
|
</body> |
|
|
</html> |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return HTMLResponse(content=html_content, status_code=200) |