Update app.py

app.py (CHANGED)
@@ -1,11 +1,5 @@
-
-from fastapi.staticfiles import StaticFiles
-from fastapi.responses import HTMLResponse, JSONResponse
-from fastapi.templating import Jinja2Templates  # For serving HTML
-from pydantic import BaseModel, Field
-from typing import List, Dict, Optional, Union
 import cv2  # OpenCV for video processing
-import uuid  # For generating unique filenames
 import os  # For interacting with the file system
 import requests  # For making HTTP requests
 import random
@@ -13,46 +7,29 @@ import string
 import json
 import shutil  # For file operations
 import ast  # For safely evaluating string literals
-import
 import asyncio  # For concurrent operations
 import time  # For retries and delays
 import logging  # For structured logging
 
-# --- Application Setup ---
-app = FastAPI(title="Advanced NSFW Video Detector API", version="1.1.0")  # Updated version
-
 # --- Logging Configuration ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-# ---
-#
-
-
-
-#
-#
-
-
-
-    os.makedirs(templates_path)  # Create if not exists for local dev
-    templates = Jinja2Templates(directory=templates_path)
-    # Create a dummy index.html if it doesn't exist for local testing
-    dummy_html_path = os.path.join(templates_path, "index.html")
-    if not os.path.exists(dummy_html_path):
-        with open(dummy_html_path, "w") as f:
-            f.write("<h1>Dummy Index Page - Replace with actual instructions</h1>")
-except Exception as e:
-    logger.warning(f"Jinja2Templates initialization failed: {e}. Will use basic HTML string for homepage.")
-    templates = None
-
-
-# --- Configuration (Potentially from environment variables or a settings file) ---
-DEFAULT_REQUEST_TIMEOUT = 20  # Increased timeout for individual NSFW checker requests
 MAX_RETRY_ATTEMPTS = 3
-RETRY_BACKOFF_FACTOR = 2
 
-# --- NSFW Checker
 NSFW_CHECKER_CONFIG = {
     "checker1_yoinked": {
         "queue_join_url": "https://yoinked-da-nsfw-checker.hf.space/queue/join",
@@ -85,9 +62,9 @@ NSFW_CHECKER_CONFIG = {
             'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 58
         }
     },
-    "checker5_phelpsgg": {
-        "queue_join_url": "https://
-        "queue_data_url_template": "https://
         "payload_template": lambda img_url, session_hash: {
             'data': [{'path': img_url}],
             'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
@@ -95,20 +72,15 @@ NSFW_CHECKER_CONFIG = {
     }
 }
 
-# ---
-tasks_db: Dict[str, Dict] = {}
-
-# --- Helper Functions ---
 async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
-    """Makes an HTTP request with retries, exponential backoff, and jitter."""
     headers = kwargs.pop("headers", {})
-    headers.setdefault("User-Agent", "
 
     for attempt in range(MAX_RETRY_ATTEMPTS):
         try:
             async with asyncio.Semaphore(10):
                 loop = asyncio.get_event_loop()
-                # For requests library, which is synchronous
                 response = await loop.run_in_executor(
                     None,
                     lambda: requests.request(method, url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT, **kwargs)
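Review note: the `asyncio.Semaphore(10)` above is created fresh on every call, so it never actually limits concurrency across requests, and the retry-delay logic falls outside this hunk. A minimal sketch of how a shared limiter and an exponential-backoff-with-jitter delay could look (the helper name and the sleep formula are assumptions, not shown in the diff):

    _HTTP_SEMAPHORE = asyncio.Semaphore(10)  # module level, shared by all calls

    async def _backoff_sleep(attempt: int) -> None:
        # 1s, 2s, 4s, ... plus up to 1s of jitter, matching RETRY_BACKOFF_FACTOR = 2
        await asyncio.sleep(RETRY_BACKOFF_FACTOR ** attempt + random.uniform(0, 1))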
@@ -135,9 +107,8 @@ async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
 
 def get_replica_code_sync(url: str) -> Optional[str]:
     try:
-        r = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers={"User-Agent": "
         r.raise_for_status()
-        # This parsing is fragile
         parts = r.text.split('replicas/')
         if len(parts) > 1:
             return parts[1].split('"};')[0]
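The `split('replicas/')` scrape (the removed comment itself called it fragile) pulls a replica identifier out of the Space's HTML. An equally unofficial but more explicit way to express the same scrape, assuming the same 'replicas/…"};' delimiters:

    import re

    def get_replica_code_sync_re(url: str):  # hypothetical variant
        r = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT)
        r.raise_for_status()
        m = re.search(r'replicas/(.*?)"};', r.text)
        return m.group(1) if m else None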
@@ -151,7 +122,6 @@ async def get_replica_code(url: str) -> Optional[str]:
     loop = asyncio.get_event_loop()
     return await loop.run_in_executor(None, get_replica_code_sync, url)
 
-
 async def parse_hf_queue_response(response_content: str) -> Optional[str]:
     try:
         messages = response_content.strip().split('\n')
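parse_hf_queue_response (its body lies outside this hunk) consumes the newline-separated server-sent events that Gradio's queue emits, where each payload sits on a `data: {...}` line. A minimal sketch of decoding one such message (the `msg`/`output` field names follow the Gradio queue protocol and should be treated as an assumption here):

    sample = 'data: {"msg": "process_completed", "output": {"data": [{"label": "sfw"}]}}'
    if sample.startswith('data:'):
        payload = json.loads(sample[len('data:'):].strip())
        if payload.get('msg') == 'process_completed':
            label = payload['output']['data'][0].get('label')  # -> 'sfw'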
@@ -191,7 +161,7 @@ async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
     payload = config["payload_template"](img_url)
     response_predict = await http_request_with_retry("POST", config["predict_url"], json=payload)
     if not response_predict or response_predict.status_code != 200:
-        logger.error(f"{checker_name} predict call failed
         return None
 
     json_data = response_predict.json()
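Checkers routed through queue endpoints (rather than a direct predict URL) need a per-request session_hash in their payload_template; this is presumably what the file's `random` and `string` imports serve. A plausible generator (the length and alphabet are assumptions; the queue only needs a unique-ish string):

    def new_session_hash(k: int = 11) -> str:
        return ''.join(random.choices(string.ascii_lowercase + string.digits, k=k))

    payload = config["payload_template"](img_url, new_session_hash())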
@@ -202,18 +172,18 @@ async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
 
     event_url = config["event_url_template"].format(event_id=event_id)
     for _ in range(10):
-        await asyncio.sleep(random.uniform(1.5, 2.5))
-        response_event = await http_request_with_retry("GET", event_url
         if response_event and response_event.status_code == 200:
-            event_stream_content = response_event.text
             if 'data:' in event_stream_content:
                 final_data_str = event_stream_content.strip().split('data:')[-1].strip()
                 if final_data_str:
                     try:
                         parsed_list = ast.literal_eval(final_data_str)
-                        if isinstance(parsed_list, list) and parsed_list:
                             return parsed_list[0].get('label')
-                        logger.warning(f"{checker_name} parsed
                     except (SyntaxError, ValueError, IndexError, TypeError) as e:
                         logger.warning(f"{checker_name} error parsing event stream: {e}, stream: {final_data_str[:100]}")
         elif response_event:
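`ast.literal_eval` is used here rather than `json.loads` because the final `data:` payload can arrive as a Python-style literal with single quotes. A worked example of the parse this loop performs (the payload content is illustrative):

    final_data_str = "[{'label': 'nsfw', 'confidence': 0.97}]"
    parsed_list = ast.literal_eval(final_data_str)
    if isinstance(parsed_list, list) and parsed_list and isinstance(parsed_list[0], dict):
        label = parsed_list[0].get('label')  # -> 'nsfw'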
@@ -247,17 +217,18 @@ async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
         return None
 
     for _ in range(15):
-        await asyncio.sleep(random.uniform(1.5, 2.5))
-        response_data = await http_request_with_retry("GET", data_url, stream=True)
         if response_data and response_data.status_code == 200:
             buffer = ""
-
                 if content_chunk:
                     buffer += content_chunk
-                    if buffer.strip().endswith("}\n\n"):
-                        label = await parse_hf_queue_response(buffer)
                         if label: return label
-                        buffer = ""
         elif response_data:
             logger.warning(f"{checker_name} polling queue/data returned status {response_data.status_code}")
         else:
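The `endswith("}\n\n")` check leans on the SSE convention that a blank line terminates each event. A variant that does not depend on trailing whitespace surviving the chunk boundaries is to split the buffer on the event separator (sketch, same convention assumed):

    *complete_events, buffer = buffer.split("\n\n")
    for event in complete_events:
        label = await parse_hf_queue_response(event)
        if label:
            return label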
@@ -272,488 +243,255 @@ async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
 
 async def check_nsfw_final_concurrent(img_url: str) -> Optional[bool]:
     logger.info(f"Starting NSFW check for: {img_url}")
     checker_names = [
         "checker2_jamescookjr90", "checker3_zanderlewis", "checker5_phelpsgg",
         "checker4_error466", "checker1_yoinked"
     ]
 
-    # Wrap tasks to carry their names for better logging upon completion
     named_tasks = {
         name: asyncio.create_task(check_nsfw_single_generic(name, img_url))
         for name in checker_names
     }
 
-    # To store if any SFW result was found
     sfw_found_by_any_checker = False
 
-
         try:
-            label = await named_tasks[task_name]
             logger.info(f"Checker '{task_name}' result for {img_url}: {label}")
             if label:
                 label_lower = label.lower()
                 if 'nsfw' in label_lower:
                     logger.info(f"NSFW detected by '{task_name}' for {img_url}. Final: True.")
-                    #
-
-
                     return True
                 if 'sfw' in label_lower or 'safe' in label_lower:
                     sfw_found_by_any_checker = True
-
-            # If label is None or not nsfw/sfw, continue to next checker's result
         except asyncio.CancelledError:
             logger.info(f"Checker '{task_name}' was cancelled for {img_url}.")
         except Exception as e:
             logger.error(f"Error processing result from checker '{task_name}' for {img_url}: {e}")
 
-    if sfw_found_by_any_checker:
-        logger.info(f"SFW confirmed for {img_url} (no NSFW detected, at least one SFW). Final: False.")
         return False
 
     logger.warning(f"All NSFW checkers inconclusive or failed for {img_url}. Final: None.")
     return None
 
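Review note: both versions await the checker tasks in fixed checker_names order, so one slow early checker delays answers from faster ones; the new code's own comment concedes that as_completed or gather would serve "first result wins" better. A sketch of the as_completed variant over the same task dict:

    for finished in asyncio.as_completed(named_tasks.values()):
        label = await finished  # earliest finisher first, regardless of list order
        if label and 'nsfw' in label.lower():
            return True  # remaining tasks can then be cancelled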
-
-
-
-
-app.mount("/static_frames", StaticFiles(directory=BASE_FRAMES_DIR), name="static_frames")
-
-def extract_frames_sync(video_path, num_frames_to_extract, request_specific_frames_dir):
     vidcap = cv2.VideoCapture(video_path)
     if not vidcap.isOpened():
-        logger.error(f"Cannot open video file
         return []
     total_frames_in_video = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
 
     if total_frames_in_video == 0:
-        logger.warning(f"Video {video_path} has no frames.")
         vidcap.release()
         return []
 
-    # Ensure num_frames_to_extract does not exceed total_frames_in_video if total_frames_in_video is small
     actual_frames_to_extract = min(num_frames_to_extract, total_frames_in_video)
-    if actual_frames_to_extract == 0 and total_frames_in_video > 0:
-
-
-    if actual_frames_to_extract == 0:  # If still zero (e.g. total_frames_in_video was 0)
         vidcap.release()
         return []
 
-
     for i in range(actual_frames_to_extract):
-
-
-
-        frame_number = min(frame_number, total_frames_in_video - 1)
 
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
         success, image = vidcap.read()
         if success:
-
-
-
-
-
         else:
-            logger.warning(f"Failed to read frame at position {frame_number} from {video_path}
-
     vidcap.release()
-
 
-async def
-
-
-
-
 
-
-
-
-        extracted_frames_disk_paths = await loop.run_in_executor(
-            None, extract_frames_sync, video_path_on_disk, num_frames_to_analyze, request_frames_subdir
-        )
-
-        if not extracted_frames_disk_paths:
-            tasks_db[task_id].update({"status": "failed", "message": "No frames could be extracted."})
-            logger.error(f"Task {task_id}: No frames extracted from {video_path_on_disk}")
-            # Clean up the video file if no frames extracted
-            if os.path.exists(video_path_on_disk): os.remove(video_path_on_disk)
-            return
-
-        tasks_db[task_id].update({
-            "status": "processing",
-            "message": f"Analyzing {len(extracted_frames_disk_paths)} frames..."
-        })
 
-
-
-
-
-
-
-            frame_filename_only = os.path.basename(frame_disk_path)
-            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
-            analysis_coroutines.append(check_nsfw_final_concurrent(img_http_url))
-
-        nsfw_detection_results = await asyncio.gather(*analysis_coroutines, return_exceptions=True)
-
-        for i, detection_result in enumerate(nsfw_detection_results):
-            frame_disk_path = extracted_frames_disk_paths[i]
-            frame_filename_only = os.path.basename(frame_disk_path)
-            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
-            is_nsfw_str = "unknown"
-
-            if isinstance(detection_result, Exception):
-                logger.error(f"Task {task_id}: Error analyzing frame {img_http_url}: {detection_result}")
-                is_nsfw_str = "error"
-            else:  # detection_result is True, False, or None
-                if detection_result is True:
-                    nsfw_count += 1
-                    is_nsfw_str = "true"
-                elif detection_result is False:
-                    is_nsfw_str = "false"
-
-            frame_results_list.append({"frame_url": img_http_url, "nsfw_detected": is_nsfw_str})
 
-
-
-
-
-
-
-
 
-
-        logger.error(f"Task {task_id}: Unhandled exception in process_video_core: {e}", exc_info=True)
-        tasks_db[task_id].update({"status": "failed", "message": f"An internal error occurred: {str(e)}"})
-    finally:
-        if os.path.exists(video_path_on_disk):
-            try:
-                os.remove(video_path_on_disk)
-                logger.info(f"Task {task_id}: Cleaned up video file: {video_path_on_disk}")
-            except OSError as e_remove:
-                logger.error(f"Task {task_id}: Error cleaning up video file {video_path_on_disk}: {e_remove}")
-        # Consider a separate job for cleaning up frame directories (request_frames_subdir) after a TTL
-
-
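process_video_core collected per-frame verdicts with gather(..., return_exceptions=True), which returns results in submission order and delivers exceptions as values instead of raising; that is why the result loop above type-checks each detection_result. A tiny self-contained illustration:

    import asyncio

    async def ok():
        return True

    async def boom():
        raise ValueError("x")

    async def main():
        return await asyncio.gather(ok(), boom(), return_exceptions=True)

    print(asyncio.run(main()))  # [True, ValueError('x')] -- order matches the inputs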
-# --- API Request/Response Models ---
-class VideoProcessRequest(BaseModel):
-    video_url: Optional[str] = Field(None, description="Publicly accessible URL of the video to process.")
-    num_frames: int = Field(10, gt=0, le=50, description="Number of frames to extract (1-50). Max 50 for performance.")  # Reduced max
-    app_base_url: str = Field(..., description="Public base URL of this API service (e.g., https://your-username-your-space-name.hf.space).")
-
-class TaskCreationResponse(BaseModel):
-    task_id: str
-    status_url: str
-    message: str
-
-class FrameResult(BaseModel):
-    frame_url: str
-    nsfw_detected: str
-
-class VideoProcessResult(BaseModel):
-    nsfw_count: int
-    total_frames_analyzed: int
-    frames: List[FrameResult]
-
-class TaskStatusResponse(BaseModel):
-    task_id: str
-    status: str
-    message: Optional[str] = None
-    result: Optional[VideoProcessResult] = None
-
-
-# --- API Endpoints ---
-@app.post("/process_video_async", response_model=TaskCreationResponse, status_code=202)
-async def process_video_from_url_async_endpoint(
-    request_data: VideoProcessRequest,  # Changed from 'request' to avoid conflict with FastAPI's Request object
-    background_tasks: BackgroundTasks
-):
-    if not request_data.video_url:
-        raise HTTPException(status_code=400, detail="video_url must be provided.")
-
-    task_id = str(uuid.uuid4())
-    tasks_db[task_id] = {"status": "pending", "message": "Task received, preparing for download."}
 
-
-
-    # Create a temporary file path for the downloaded video
-    # The actual download will also be part of the background task to avoid blocking.
-    # For now, keeping initial download here for simplicity of passing path.
-    # A more robust way: background_tasks.add_task(download_and_then_process, task_id, request_data.video_url, ...)
-
-    # Using a temporary directory specific to this task for the downloaded video
-    task_download_dir = os.path.join(BASE_FRAMES_DIR, "_video_downloads", task_id)
-    os.makedirs(task_download_dir, exist_ok=True)
-    # Suffix from URL or default
-    video_suffix = os.path.splitext(request_data.video_url.split("?")[0])[-1] or ".mp4"  # Basic suffix extraction
-    if not video_suffix.startswith("."): video_suffix = "." + video_suffix
 
 
-
-
-
-
-        dl_response = await http_request_with_retry("GET", request_data.video_url, stream=True)
-        if not dl_response or dl_response.status_code != 200:
-            if os.path.exists(task_download_dir): shutil.rmtree(task_download_dir)
-            tasks_db[task_id].update({"status": "failed", "message": f"Failed to download video. Status: {dl_response.status_code if dl_response else 'N/A'}"})
-            raise HTTPException(status_code=400, detail=f"Error downloading video: Status {dl_response.status_code if dl_response else 'N/A'}")
-
-        with open(temp_video_file_path, "wb") as f:
-            for chunk in dl_response.iter_content(chunk_size=8192*4):  # Increased chunk size
-                f.write(chunk)
-        logger.info(f"Task {task_id}: Video downloaded to {temp_video_file_path}")
-
-        background_tasks.add_task(process_video_core, task_id, temp_video_file_path, request_data.num_frames, request_data.app_base_url)
-
-        status_url_path = app.url_path_for("get_task_status_endpoint", task_id=task_id)
-        full_status_url = str(request_data.app_base_url.rstrip('/') + status_url_path)
-
-        return TaskCreationResponse(
-            task_id=task_id,
-            status_url=full_status_url,
-            message="Video processing task accepted and started in background."
-        )
-
-    except requests.exceptions.RequestException as e:
-        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path))
-        tasks_db[task_id].update({"status": "failed", "message": f"Video download error: {e}"})
-        raise HTTPException(status_code=400, detail=f"Error downloading video: {e}")
-    except Exception as e:
-        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path))
-        logger.error(f"Task {task_id}: Unexpected error during task submission: {e}", exc_info=True)
-        tasks_db[task_id].update({"status": "failed", "message": "Internal server error during task submission."})
-        raise HTTPException(status_code=500, detail="Internal server error during task submission.")
-
-
-@app.post("/upload_video_async", response_model=TaskCreationResponse, status_code=202)
-async def upload_video_async_endpoint(
-    background_tasks: BackgroundTasks,
-    app_base_url: str = Form(..., description="Public base URL of this API service."),
-    num_frames: int = Form(10, gt=0, le=50, description="Number of frames to extract (1-50)."),
-    video_file: UploadFile = File(..., description="Video file to upload and process.")
-):
-    if not video_file.content_type or not video_file.content_type.startswith("video/"):
-        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
-
-    task_id = str(uuid.uuid4())
-    tasks_db[task_id] = {"status": "pending", "message": "Task received, saving uploaded video."}
-    temp_video_file_path = None
-    try:
-        upload_dir = os.path.join(BASE_FRAMES_DIR, "_video_uploads", task_id)  # Task-specific upload dir
-        os.makedirs(upload_dir, exist_ok=True)
 
-
-        if
-
-
 
-
-
-
-
-
-
-
-
-        return TaskCreationResponse(
-            task_id=task_id,
-            status_url=full_status_url,
-            message="Video upload accepted and processing started in background."
-        )
-    except Exception as e:
-        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path))
-        logger.error(f"Task {task_id}: Error handling video upload: {e}", exc_info=True)
-        tasks_db[task_id].update({"status": "failed", "message": "Internal server error during video upload."})
-        raise HTTPException(status_code=500, detail=f"Error processing uploaded file: {e}")
-    finally:
-        if video_file:
-            await video_file.close()
-
-
-@app.get("/tasks/{task_id}/status", response_model=TaskStatusResponse)
-async def get_task_status_endpoint(task_id: str):
-    task = tasks_db.get(task_id)
-    if not task:
-        raise HTTPException(status_code=404, detail="Task not found.")
-    return TaskStatusResponse(task_id=task_id, **task)
-
-# --- Homepage Endpoint ---
-@app.get("/", response_class=HTMLResponse)
-async def read_root(fastapi_request: Request):  # Renamed from 'request' to avoid conflict
-    # Try to determine app_base_url automatically if possible (might be tricky behind proxies)
-    # For Hugging Face, the X-Forwarded-Host or similar headers might be useful.
-    # A simpler approach for HF is to have the user provide it or construct it.
-    # For the example curl, let's use a placeholder.
 
-    #
-
-
-
-
-
-
 
-
-
-
-
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    <body>
-    <div class="container">
-        <h1>NSFW Video Detector API</h1>
-        <p>This API allows you to process videos to detect Not Suitable For Work (NSFW) content. It works asynchronously: you submit a video (via URL or direct upload), receive a task ID, and then poll a status endpoint to get the results.</p>
-
-        <div class="note">
-            <p><span class="param">Important:</span> The <code>app_base_url</code> parameter is crucial. It must be the public base URL where this API service is accessible. For example, if your Hugging Face Space URL is <code>https://your-username-your-space-name.hf.space</code>, then that's your <code>app_base_url</code>.</p>
-            <p>Current detected example base URL for instructions: <code>{example_app_base_url}</code> (This is a guess, please verify your actual public URL).</p>
-        </div>
-
-        <h2>Endpoints</h2>
-
-        <div class="endpoint">
-            <h3>1. Process Video from URL (Asynchronous)</h3>
-            <p><code>POST /process_video_async</code></p>
-            <p>Submits a video from a public URL for NSFW analysis.</p>
-            <h4>Request Body (JSON):</h4>
-            <table>
-                <tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr>
-                <tr><td><span class="param">video_url</span></td><td>string</td><td><em>Required</em></td><td>Publicly accessible URL of the video.</td></tr>
-                <tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract and analyze (1-50).</td></tr>
-                <tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr>
-            </table>
-            <h4>Example using <code>curl</code>:</h4>
-            <pre><code>curl -X POST "{example_app_base_url}/process_video_async" \\
-    -H "Content-Type: application/json" \\
-    -d '{{
-        "video_url": "YOUR_PUBLIC_VIDEO_URL_HERE.mp4",
-        "num_frames": 5,
-        "app_base_url": "{example_app_base_url}"
-    }}'</code></pre>
-        </div>
-
-        <div class="endpoint">
-            <h3>2. Upload Video File (Asynchronous)</h3>
-            <p><code>POST /upload_video_async</code></p>
-            <p>Uploads a video file directly for NSFW analysis.</p>
-            <h4>Request Body (Multipart Form-Data):</h4>
-            <table>
-                <tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr>
-                <tr><td><span class="param">video_file</span></td><td>file</td><td><em>Required</em></td><td>The video file to upload.</td></tr>
-                <tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract (1-50).</td></tr>
-                <tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr>
-            </table>
-            <h4>Example using <code>curl</code>:</h4>
-            <pre><code>curl -X POST "{example_app_base_url}/upload_video_async" \\
-    -F "video_file=@/path/to/your/video.mp4" \\
-    -F "num_frames=5" \\
-    -F "app_base_url={example_app_base_url}"</code></pre>
-        </div>
-
-        <div class="tip">
-            <h4>Response for Task Creation (for both URL and Upload):</h4>
-            <p>If successful (HTTP 202 Accepted), the API will respond with:</p>
-            <pre><code>{{
-    "task_id": "some-unique-task-id",
-    "status_url": "{example_app_base_url}/tasks/some-unique-task-id/status",
-    "message": "Video processing task accepted and started in background."
-}}</code></pre>
-        </div>
-
-        <div class="endpoint">
-            <h3>3. Get Task Status and Result</h3>
-            <p><code>GET /tasks/<task_id>/status</code></p>
-            <p>Poll this endpoint to check the status of a processing task and retrieve the result once completed.</p>
-            <h4>Example using <code>curl</code>:</h4>
-            <pre><code>curl -X GET "{example_app_base_url}/tasks/some-unique-task-id/status"</code></pre>
-            <h4>Possible Statuses:</h4>
-            <ul>
-                <li><code>pending</code>: Task is queued.</li>
-                <li><code>processing</code>: Task is actively being processed (downloading, extracting frames, analyzing).</li>
-                <li><code>completed</code>: Task finished successfully. Results are available in the <code>result</code> field.</li>
-                <li><code>failed</code>: Task failed. Check the <code>message</code> field for details.</li>
-            </ul>
-            <h4>Example Response (Status: <code>completed</code>):</h4>
-            <pre><code>{{
-    "task_id": "some-unique-task-id",
-    "status": "completed",
-    "message": "Processing complete.",
-    "result": {{
-        "nsfw_count": 1,
-        "total_frames_analyzed": 5,
-        "frames": [
-            {{
-                "frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid1.jpg",
-                "nsfw_detected": "false"
-            }},
-            {{
-                "frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid2.jpg",
-                "nsfw_detected": "true"
-            }}
-            // ... more frames
-        ]
-    }}
-}}</code></pre>
-            <h4>Example Response (Status: <code>processing</code>):</h4>
-            <pre><code>{{
-    "task_id": "some-unique-task-id",
-    "status": "processing",
-    "message": "Analyzing 5 frames...",
-    "result": null
-}}</code></pre>
-        </div>
-        <p style="text-align:center; margin-top:30px; font-size:0.9em; color:#777;">API Version: {app.version}</p>
-    </div>
-    </body>
-    </html>
-    """
-    # If using Jinja2 templates:
-    # if templates:
-    #     return templates.TemplateResponse("index.html", {"request": fastapi_request, "app_version": app.version, "example_app_base_url": example_app_base_url})
-    # else:
-    #     return HTMLResponse(content=html_content, status_code=200)
-    return HTMLResponse(content=html_content, status_code=200)
-
-
-# Example of how to run for local development:
-# 1. Ensure you have a 'templates/index.html' file or the fallback HTML will be used.
-# 2. Run: uvicorn main:app --reload --host 0.0.0.0 --port 8000
-#    (assuming your file is named main.py)
-# Requirements: fastapi uvicorn[standard] opencv-python requests pydantic python-multipart (for Form/File uploads)

+import gradio as gr
 import cv2  # OpenCV for video processing
 import os  # For interacting with the file system
 import requests  # For making HTTP requests
 import random
...
 import json
 import shutil  # For file operations
 import ast  # For safely evaluating string literals
+import uuid  # For generating unique filenames
 import asyncio  # For concurrent operations
 import time  # For retries and delays
 import logging  # For structured logging
 
 # --- Logging Configuration ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
+# --- Configuration ---
+# Directory for temporarily storing extracted frames.
+GRADIO_TEMP_FRAME_DIR = "/tmp/gradio_nsfw_frames_advanced"
+os.makedirs(GRADIO_TEMP_FRAME_DIR, exist_ok=True)
+
+# The public URL of this Gradio Space. Crucial for external NSFW checkers.
+# Set via environment variable or update placeholder if hardcoding.
+APP_BASE_URL = os.getenv("APP_BASE_URL", "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE")
+
+DEFAULT_REQUEST_TIMEOUT = 20
 MAX_RETRY_ATTEMPTS = 3
+RETRY_BACKOFF_FACTOR = 2
 
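Review note: this revision deletes the old `from typing import List, Dict, Optional, Union` line, yet the retained helper signatures still use Optional (e.g. `-> Optional[requests.Response]`). As written, importing the module raises NameError, which plausibly explains the Space's "Runtime error" status; restoring the needed name fixes it:

    from typing import Optional  # still required by the retained type hints

Separately, instead of hand-setting APP_BASE_URL, Hugging Face Spaces exposes the running Space's host in the SPACE_HOST environment variable; a fallback along these lines could reduce misconfiguration (treat the variable's availability as an assumption to verify for your runtime):

    space_host = os.getenv("SPACE_HOST")  # e.g. "user-spacename.hf.space"
    APP_BASE_URL = os.getenv("APP_BASE_URL") or (f"https://{space_host}" if space_host else "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE")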
+# --- NSFW Checker Configuration (from FastAPI version) ---
 NSFW_CHECKER_CONFIG = {
     "checker1_yoinked": {
         "queue_join_url": "https://yoinked-da-nsfw-checker.hf.space/queue/join",
...
             'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 58
         }
     },
+    "checker5_phelpsgg": {  # Original user code had 'imseldrith', FastAPI had 'phelpsgg'. Using phelpsgg from FastAPI.
+        "queue_join_url": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/join",
+        "queue_data_url_template": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
         "payload_template": lambda img_url, session_hash: {
             'data': [{'path': img_url}],
             'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
...
     }
 }
 
+# --- Helper Functions (from FastAPI version, adapted for Gradio context) ---
 async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
     headers = kwargs.pop("headers", {})
+    headers.setdefault("User-Agent", "GradioNSFWClient/1.0")
 
     for attempt in range(MAX_RETRY_ATTEMPTS):
         try:
             async with asyncio.Semaphore(10):
                 loop = asyncio.get_event_loop()
                 response = await loop.run_in_executor(
                     None,
                     lambda: requests.request(method, url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT, **kwargs)
...
 def get_replica_code_sync(url: str) -> Optional[str]:
     try:
+        r = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers={"User-Agent": "GradioNSFWClient/1.0"})
         r.raise_for_status()
         parts = r.text.split('replicas/')
         if len(parts) > 1:
             return parts[1].split('"};')[0]
...
     loop = asyncio.get_event_loop()
     return await loop.run_in_executor(None, get_replica_code_sync, url)
 
 async def parse_hf_queue_response(response_content: str) -> Optional[str]:
     try:
         messages = response_content.strip().split('\n')
...
     payload = config["payload_template"](img_url)
     response_predict = await http_request_with_retry("POST", config["predict_url"], json=payload)
     if not response_predict or response_predict.status_code != 200:
+        logger.error(f"{checker_name} predict call failed. Status: {response_predict.status_code if response_predict else 'N/A'}")
         return None
 
     json_data = response_predict.json()
...
     event_url = config["event_url_template"].format(event_id=event_id)
     for _ in range(10):
+        await asyncio.sleep(random.uniform(1.5, 2.5))
+        response_event = await http_request_with_retry("GET", event_url)  # Removed stream=True as iter_content not used directly
         if response_event and response_event.status_code == 200:
+            event_stream_content = response_event.text
             if 'data:' in event_stream_content:
                 final_data_str = event_stream_content.strip().split('data:')[-1].strip()
                 if final_data_str:
                     try:
                         parsed_list = ast.literal_eval(final_data_str)
+                        if isinstance(parsed_list, list) and parsed_list and isinstance(parsed_list[0], dict):
                             return parsed_list[0].get('label')
+                        logger.warning(f"{checker_name} parsed non-list or empty list from event stream: {final_data_str[:100]}")
                     except (SyntaxError, ValueError, IndexError, TypeError) as e:
                         logger.warning(f"{checker_name} error parsing event stream: {e}, stream: {final_data_str[:100]}")
         elif response_event:
...
         return None
 
     for _ in range(15):
+        await asyncio.sleep(random.uniform(1.5, 2.5))
+        response_data = await http_request_with_retry("GET", data_url, stream=True)
         if response_data and response_data.status_code == 200:
             buffer = ""
+            # iter_content is synchronous, but http_request_with_retry runs it in executor
+            for content_chunk in response_data.iter_content(chunk_size=1024, decode_unicode=True):
                 if content_chunk:
                     buffer += content_chunk
+                    if buffer.strip().endswith("}\n\n"):
+                        label = await parse_hf_queue_response(buffer)  # parse_hf_queue_response is async
                         if label: return label
+                        buffer = ""
         elif response_data:
             logger.warning(f"{checker_name} polling queue/data returned status {response_data.status_code}")
         else:
...
 
 async def check_nsfw_final_concurrent(img_url: str) -> Optional[bool]:
     logger.info(f"Starting NSFW check for: {img_url}")
+    # Prioritized list from FastAPI version
     checker_names = [
         "checker2_jamescookjr90", "checker3_zanderlewis", "checker5_phelpsgg",
         "checker4_error466", "checker1_yoinked"
     ]
 
     named_tasks = {
         name: asyncio.create_task(check_nsfw_single_generic(name, img_url))
         for name in checker_names
     }
 
     sfw_found_by_any_checker = False
 
+    # Iterate and await tasks. Since as_completed is not used, order of results depends on await order.
+    # For true "first result wins" or concurrent processing, as_completed or gather is better.
+    # This simplified loop awaits them one by one based on checker_names order.
+    for task_name in checker_names:
         try:
+            label = await named_tasks[task_name]
             logger.info(f"Checker '{task_name}' result for {img_url}: {label}")
             if label:
                 label_lower = label.lower()
                 if 'nsfw' in label_lower:
                     logger.info(f"NSFW detected by '{task_name}' for {img_url}. Final: True.")
+                    # Cancel remaining tasks
+                    for t_name_to_cancel, t_obj_to_cancel in named_tasks.items():
+                        if t_name_to_cancel != task_name and not t_obj_to_cancel.done():
+                            t_obj_to_cancel.cancel()
                     return True
                 if 'sfw' in label_lower or 'safe' in label_lower:
                     sfw_found_by_any_checker = True
+
         except asyncio.CancelledError:
             logger.info(f"Checker '{task_name}' was cancelled for {img_url}.")
         except Exception as e:
             logger.error(f"Error processing result from checker '{task_name}' for {img_url}: {e}")
 
+    if sfw_found_by_any_checker:
+        logger.info(f"SFW confirmed for {img_url} (no NSFW detected by any checker, at least one SFW). Final: False.")
         return False
 
     logger.warning(f"All NSFW checkers inconclusive or failed for {img_url}. Final: None.")
     return None
 
+# --- Gradio Specific Functions ---
+def extract_frames_sync(video_path: str, num_frames_to_extract: int, progress: gr.Progress = None) -> list:
+    if progress: progress(0, desc="Starting frame extraction...")
+
     vidcap = cv2.VideoCapture(video_path)
     if not vidcap.isOpened():
+        logger.error(f"Error: Cannot open video file {video_path}")
         return []
+
     total_frames_in_video = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    extracted_filenames = []  # Stores only filenames, not full paths
 
     if total_frames_in_video == 0:
         vidcap.release()
         return []
 
     actual_frames_to_extract = min(num_frames_to_extract, total_frames_in_video)
+    if actual_frames_to_extract == 0 and total_frames_in_video > 0: actual_frames_to_extract = 1
+    if actual_frames_to_extract == 0:
         vidcap.release()
         return []
 
     for i in range(actual_frames_to_extract):
+        if progress: progress(i / actual_frames_to_extract, desc=f"Extracting frame {i+1}/{actual_frames_to_extract}")
+
+        frame_number = int(i * total_frames_in_video / actual_frames_to_extract)
+        frame_number = min(frame_number, total_frames_in_video - 1)
 
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
         success, image = vidcap.read()
         if success:
+            random_filename = f"{uuid.uuid4().hex}.jpg"
+            full_frame_path = os.path.join(GRADIO_TEMP_FRAME_DIR, random_filename)
+            try:
+                cv2.imwrite(full_frame_path, image)
+                extracted_filenames.append(random_filename)
+            except Exception as e:
+                logger.error(f"Error writing frame {full_frame_path}: {e}")
         else:
+            logger.warning(f"Warning: Failed to read frame at position {frame_number} from {video_path}")
+
     vidcap.release()
+    if progress: progress(1, desc="Frame extraction complete.")
+    return extracted_filenames
 
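The sampling rule `frame_number = int(i * total / n)` spaces the n requested frames evenly from the start of the video. A quick worked example (values illustrative, not from the diff):

    total, n = 100, 5
    picks = [min(int(i * total / n), total - 1) for i in range(n)]
    # picks == [0, 20, 40, 60, 80]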
+async def process_video_gradio(video_temp_path: str, num_frames: int, progress=gr.Progress(track_tqdm=True)):
+    """
+    Main async processing function for Gradio, using robust NSFW checkers.
+    """
+    if APP_BASE_URL == "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE":
+        warning_msg = "CRITICAL WARNING: APP_BASE_URL is not set! External NSFW checks will likely fail. Please set the APP_BASE_URL environment variable to your public Gradio Space URL."
+        logger.error(warning_msg)
+        # Optionally, display this warning in the Gradio UI as well
+        # For now, returning it in the JSON output.
+        return {"error": warning_msg, "details": "The application needs to know its own public URL to construct frame URLs for external analysis services."}
+
+    if not video_temp_path:
+        return {"error": "No video file provided or video path is invalid."}
 
+    try:
+        num_frames = int(num_frames)
+        if num_frames <= 0:
+            return {"error": "Number of frames must be a positive integer."}
+    except (ValueError, TypeError):
+        return {"error": "Invalid number for frames."}
 
+    # Run synchronous frame extraction in an executor to keep the async event loop unblocked
+    loop = asyncio.get_event_loop()
+    extracted_frame_filenames = await loop.run_in_executor(
+        None, extract_frames_sync, video_temp_path, num_frames, progress
+    )
+
+    if not extracted_frame_filenames:
+        # Cleanup the uploaded video file if it exists and extraction failed
+        if os.path.exists(video_temp_path):
+            try: os.remove(video_temp_path)
+            except Exception as e: logger.error(f"Error cleaning up input video {video_temp_path} after failed extraction: {e}")
+        return {"error": "Could not extract any frames from the video."}
+
+    nsfw_count = 0
+    total_frames_processed = len(extracted_frame_filenames)
+    frame_results_output = []
+
+    analysis_coroutines = []
+    for frame_filename in extracted_frame_filenames:
+        # Construct the URL for the NSFW checker using Gradio's /file= route
+        # The path for /file= should be the absolute path on the server where Gradio can find the file.
+        absolute_frame_path_on_server = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
+        publicly_accessible_frame_url = f"{APP_BASE_URL.rstrip('/')}/file={absolute_frame_path_on_server}"
+        analysis_coroutines.append(check_nsfw_final_concurrent(publicly_accessible_frame_url))
+
+    # Update progress for analysis phase
+    # Since gr.Progress doesn't directly map to asyncio.gather, we'll set a general message.
+    if progress: progress(0.5, desc=f"Analyzing {total_frames_processed} frames (may take time)...")
 
+    nsfw_detection_results = await asyncio.gather(*analysis_coroutines, return_exceptions=True)
 
+    # Update progress after analysis
+    if progress: progress(0.9, desc="Compiling results...")
 
 
+    for i, detection_result in enumerate(nsfw_detection_results):
+        frame_filename = extracted_frame_filenames[i]
+        absolute_frame_path_on_server = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
+        publicly_accessible_frame_url = f"{APP_BASE_URL.rstrip('/')}/file={absolute_frame_path_on_server}"
 
+        is_nsfw_str = "unknown"
+        if isinstance(detection_result, Exception):
+            logger.error(f"Error analyzing frame {publicly_accessible_frame_url}: {detection_result}")
+            is_nsfw_str = "error"
+        else:  # detection_result is True, False, or None
+            if detection_result is True:
+                nsfw_count += 1
+                is_nsfw_str = "true"
+            elif detection_result is False:
+                is_nsfw_str = "false"
+
+        frame_results_output.append({
+            "frame_filename_on_server": frame_filename,
+            "checked_url": publicly_accessible_frame_url,
+            "nsfw_detected": is_nsfw_str
+        })
+
+    if progress: progress(1, desc="Analysis complete. Cleaning up temporary files...")
 
+    # Cleanup extracted frames
+    for frame_filename in extracted_frame_filenames:
+        full_frame_path_to_delete = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
+        if os.path.exists(full_frame_path_to_delete):
+            try:
+                os.remove(full_frame_path_to_delete)
+            except Exception as e:
+                logger.error(f"Error deleting frame {full_frame_path_to_delete}: {e}")
 
+    # Gradio manages the `video_temp_path` (uploaded video) cleanup.
+
+    final_result_json = {
+        "summary": {
+            "nsfw_frames_found": nsfw_count,
+            "total_frames_analyzed": total_frames_processed,
+            "app_base_url_used_for_checks": APP_BASE_URL,
+            "frames_temp_dir_on_server": GRADIO_TEMP_FRAME_DIR
+        },
+        "frame_details": frame_results_output
+    }
 
+    return final_result_json
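For orientation, the dictionary handed to the gr.JSON output renders roughly like this (values illustrative):

    # Illustrative shape of final_result_json:
    {
        "summary": {
            "nsfw_frames_found": 1,
            "total_frames_analyzed": 5,
            "app_base_url_used_for_checks": "https://your-username-your-spacename.hf.space",
            "frames_temp_dir_on_server": "/tmp/gradio_nsfw_frames_advanced"
        },
        "frame_details": [
            {"frame_filename_on_server": "3f9c.jpg", "checked_url": "https://.../file=/tmp/gradio_nsfw_frames_advanced/3f9c.jpg", "nsfw_detected": "false"}
        ]
    }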
+
+# --- Gradio Interface Definition ---
+with gr.Blocks(css="footer {display: none !important;}", title="NSFW Video Detector") as app_interface:
+    gr.Markdown(
+        f"""
+        # NSFW Frame Detection in Video
+        Upload a video and specify the number of frames to check for NSFW content.
+        The analysis uses a series of external NSFW detection models with retries and concurrent checks for robustness.
+        """
+    )
+    with gr.Accordion("Important: How this Space Works & Configuration", open=False):
+        gr.Markdown(
+            f"""
+            - **`APP_BASE_URL`**: For this Space to work correctly when deployed (e.g., on Hugging Face Spaces), the `APP_BASE_URL` environment variable **must** be set to its public URL (e.g., `https://your-username-your-spacename.hf.space`). This is because external NSFW checkers need to access the extracted video frames via public URLs.
+            - **Currently configured `APP_BASE_URL` for frame checking: `{APP_BASE_URL}`**. If this shows the placeholder, the NSFW checks will likely fail.
+            - **Temporary Frame Storage**: Frames are temporarily extracted to `{GRADIO_TEMP_FRAME_DIR}` on the server and deleted after processing.
+            - **Processing Time**: Depends on video length, number of frames selected, and the responsiveness of external NSFW checking services. Please be patient.
+            """
+        )
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            video_input = gr.Video(label="Upload Video")
+            num_frames_input = gr.Number(label="Number of Frames to Check (e.g., 5-20)", value=10, minimum=1, maximum=50, step=1, precision=0)
+            submit_button = gr.Button("Detect NSFW Content", variant="primary")
+        with gr.Column(scale=2):
+            output_json = gr.JSON(label="Detection Result")
 
+    submit_button.click(
+        fn=process_video_gradio,  # Use the new async handler
+        inputs=[video_input, num_frames_input],
+        outputs=output_json
+    )
+
+    gr.Examples(
+        examples=[
+            # Provide path to a sample video if you have one in your Space repo
+            # Example: [os.path.join(os.path.dirname(__file__), "sample_video.mp4"), 5],
+        ],
+        inputs=[video_input, num_frames_input],
+        outputs=output_json,
+        fn=process_video_gradio,  # Ensure example uses the async handler too
+        cache_examples=False
+    )
+    gr.Markdown("Note: If `APP_BASE_URL` is not correctly set to this Space's public URL, the NSFW detection for frames will fail as external services won't be able to access them.")
+
+if __name__ == "__main__":
+    if APP_BASE_URL == "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE":  # Check placeholder
+        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+        print("!!! CRITICAL WARNING: APP_BASE_URL is NOT SET or using a placeholder.                                 !!!")
+        print("!!! External NSFW checks will likely FAIL.                                                            !!!")
+        print("!!! For local testing: Expose this app (e.g., with ngrok) and set APP_BASE_URL to the ngrok URL.      !!!")
+        print("!!! When deploying to Hugging Face Spaces, set the APP_BASE_URL environment variable in Space settings.!!!")
+        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+
+    # For Gradio, app.launch() is typically how you start it.
+    # If running in a Hugging Face Space with a Dockerfile, the CMD usually is `python app.py`.
+    # Gradio handles the web server part.
+    app_interface.queue()  # Enable queue for handling multiple requests and longer tasks better.
+    app_interface.launch()  # share=True can be used for temporary public link if running locally.