imseldrith committed on
Commit
e4daf0b
·
verified ·
1 Parent(s): 26d31c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +759 -0
app.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI service that screens videos for NSFW content.

A video is submitted by URL or direct upload, sampled into frames, and each
frame is sent to several third-party Hugging Face Space classifiers. The work
runs in background tasks; clients poll a status endpoint for the result.
"""
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks, Depends, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates  # For serving HTML
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Union
import cv2  # OpenCV for video processing
import uuid  # For generating unique filenames
import os  # For interacting with the file system
import requests  # For making HTTP requests
import random
import string
import json
import shutil  # For file operations
import ast  # For safely evaluating string literals
import tempfile  # For creating temporary directories/files
import asyncio  # For concurrent operations
import time  # For retries and delays
import logging  # For structured logging

# --- Application Setup ---
app = FastAPI(title="Advanced NSFW Video Detector API", version="1.1.0")  # Updated version

# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Templates for HTML Homepage ---
# A 'templates' directory (next to this file) holding 'index.html' is used if
# available. If Jinja2 setup fails (e.g. restricted filesystem on a hosted
# environment), `templates` is set to None and the homepage falls back to an
# inline HTML string. Note this block runs at import time and creates the
# directory plus a placeholder index.html when missing.
try:
    templates_path = os.path.join(os.path.dirname(__file__), "templates")
    if not os.path.exists(templates_path):
        os.makedirs(templates_path)  # Create if not exists for local dev
    templates = Jinja2Templates(directory=templates_path)
    # Create a dummy index.html if it doesn't exist for local testing
    dummy_html_path = os.path.join(templates_path, "index.html")
    if not os.path.exists(dummy_html_path):
        with open(dummy_html_path, "w") as f:
            f.write("<h1>Dummy Index Page - Replace with actual instructions</h1>")
except Exception as e:
    logger.warning(f"Jinja2Templates initialization failed: {e}. Will use basic HTML string for homepage.")
    templates = None
48
+
49
+
50
# --- Configuration (Potentially from environment variables or a settings file) ---
DEFAULT_REQUEST_TIMEOUT = 20  # Seconds per individual NSFW-checker HTTP request (increased timeout).
MAX_RETRY_ATTEMPTS = 3  # Attempts made by http_request_with_retry before giving up.
RETRY_BACKOFF_FACTOR = 2  # In seconds; base of the exponential backoff between retries.

# --- NSFW Checker URLs (Ideally, these would be in a config) ---
# Each entry describes one remote classifier Space. Two API shapes appear:
#   * Gradio queue Spaces: "queue_join_url" + "queue_data_url_template"
#     (checker4 additionally needs a replica code scraped from "base_url",
#     hence its *_template join URL and "replica_code_needed" flag);
#   * call/predict Spaces (checker3): "predict_url" + "event_url_template".
# "payload_template" builds the request body for a given image URL (and, for
# queue-based Spaces, the per-request session hash).
NSFW_CHECKER_CONFIG = {
    "checker1_yoinked": {
        "queue_join_url": "https://yoinked-da-nsfw-checker.hf.space/queue/join",
        "queue_data_url_template": "https://yoinked-da-nsfw-checker.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}, "chen-convnext", 0.5, True, True],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 12
        }
    },
    "checker2_jamescookjr90": {
        "queue_join_url": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    },
    "checker3_zanderlewis": {
        "predict_url": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict",
        "event_url_template": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict/{event_id}",
        "payload_template": lambda img_url: {'data': [{'path': img_url}]}
    },
    "checker4_error466": {
        "base_url": "https://error466-falconsai-nsfw-image-detection.hf.space",
        "replica_code_needed": True,
        "queue_join_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/join",
        "queue_data_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 58
        }
    },
    "checker5_phelpsgg": {
        "queue_join_url": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    }
}

# --- Task Management for Asynchronous Processing ---
# In-memory task registry keyed by task id. NOTE(review): not persistent and
# not shared between worker processes — fine for a single-process deployment.
tasks_db: Dict[str, Dict] = {}
100
+
101
+ # --- Helper Functions ---
102
async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
    """Make an HTTP request with retries, exponential backoff, and jitter.

    The blocking ``requests`` call runs in the default thread-pool executor so
    the event loop is not blocked. Retries on timeouts, connection errors, and
    retriable HTTP statuses (429/502/503/504).

    Returns the successful ``Response``; for a non-retriable HTTP error, the
    error response (or None); None when all attempts are exhausted.
    """
    headers = kwargs.pop("headers", {})
    headers.setdefault("User-Agent", "NSFWDetectorClient/1.1")

    for attempt in range(MAX_RETRY_ATTEMPTS):
        try:
            # FIX: the original wrapped each attempt in a freshly created
            # `asyncio.Semaphore(10)`. A semaphore constructed per call can
            # never be contended, so it limited nothing and has been removed.
            # Share one module-level semaphore if global throttling is wanted.
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                None,
                lambda: requests.request(method, url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT, **kwargs)
            )
            response.raise_for_status()
            return response
        except requests.exceptions.Timeout:
            logger.warning(f"Request timeout for {url} on attempt {attempt + 1}")
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code in [429, 502, 503, 504]:
                logger.warning(f"HTTP error {e.response.status_code} for {url} on attempt {attempt + 1}")
            else:
                # Client errors etc. will not improve on retry; surface them.
                logger.error(f"Non-retriable HTTP error for {url}: {e}")
                return e.response if e.response is not None else None
        except requests.exceptions.RequestException as e:
            logger.error(f"Request exception for {url} on attempt {attempt + 1}: {e}")

        if attempt < MAX_RETRY_ATTEMPTS - 1:
            # Exponential backoff plus a little jitter to avoid thundering herds.
            delay = (RETRY_BACKOFF_FACTOR ** attempt) + random.uniform(0, 0.5)
            logger.info(f"Retrying {url} in {delay:.2f} seconds...")
            await asyncio.sleep(delay)
    logger.error(f"All {MAX_RETRY_ATTEMPTS} retry attempts failed for {url}.")
    return None
135
+
136
def get_replica_code_sync(url: str) -> Optional[str]:
    """Scrape the replica identifier out of a Hugging Face Space page.

    Fetches *url* and pulls the token between the first ``replicas/`` marker
    and the following ``"};`` in the HTML. Returns None on any request error
    or when the marker is absent.
    """
    try:
        resp = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers={"User-Agent": "NSFWDetectorClient/1.1"})
        resp.raise_for_status()
        # Fragile scrape: depends on the Space's current HTML layout.
        fragments = resp.text.split('replicas/')
        if len(fragments) > 1:
            return fragments[1].split('"};')[0]
        logger.warning(f"Could not find 'replicas/' in content from {url}")
        return None
    except (requests.exceptions.RequestException, IndexError, KeyError) as exc:
        logger.error(f"Error getting replica code for {url}: {exc}")
        return None
149
+
150
async def get_replica_code(url: str) -> Optional[str]:
    """Async wrapper: run the blocking replica-code scrape in the default executor."""
    return await asyncio.get_event_loop().run_in_executor(None, get_replica_code_sync, url)
153
+
154
+
155
async def parse_hf_queue_response(response_content: str) -> Optional[str]:
    """Extract the classification label from a Gradio queue SSE payload.

    Scans the server-sent-event lines newest-first for a ``process_completed``
    message and returns its first output item's ``label`` (or the item itself
    when it is a plain string). Returns None when no conclusive message is
    found or the payload is malformed.
    """
    try:
        for line in reversed(response_content.strip().split('\n')):
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):].strip()
            if not payload:
                continue
            try:
                message = json.loads(payload)
            except json.JSONDecodeError:
                logger.debug(f"Failed to decode JSON from part of HF stream: {payload[:100]}")
                continue
            if message.get("msg") != "process_completed":
                continue
            output_items = message.get("output", {}).get("data")
            if output_items and isinstance(output_items, list) and len(output_items) > 0:
                head = output_items[0]
                if isinstance(head, dict):
                    return head.get('label')
                if isinstance(head, str):
                    return head
            # A completed message with an unrecognised shape is terminal.
            logger.warning(f"Unexpected 'process_completed' data structure: {output_items}")
            return None
        return None
    except Exception as exc:
        logger.error(f"Error parsing HF queue response: {exc}, content: {response_content[:200]}")
        return None
180
+
181
async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
    """Query one configured remote checker for the label of *img_url*.

    Dispatches on the config shape: Spaces exposing /call/predict are polled
    via an event URL; Gradio queue Spaces are joined and their SSE data stream
    is polled. Returns the raw label string (e.g. 'nsfw', 'sfw') or None when
    the checker fails or is inconclusive. All exceptions are swallowed and
    logged so one broken checker never aborts the aggregate check.
    """
    config = NSFW_CHECKER_CONFIG.get(checker_name)
    if not config:
        logger.error(f"No configuration found for checker: {checker_name}")
        return None

    # Random per-request session hash, as required by the Gradio queue protocol.
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(10))

    try:
        if "predict_url" in config:  # call/predict-style Space (ZanderLewis-like)
            payload = config["payload_template"](img_url)
            response_predict = await http_request_with_retry("POST", config["predict_url"], json=payload)
            if not response_predict or response_predict.status_code != 200:
                logger.error(f"{checker_name} predict call failed or returned non-200. Status: {response_predict.status_code if response_predict else 'N/A'}")
                return None

            json_data = response_predict.json()
            event_id = json_data.get('event_id')
            if not event_id:
                logger.error(f"{checker_name} did not return event_id.")
                return None

            event_url = config["event_url_template"].format(event_id=event_id)
            for _ in range(10):
                await asyncio.sleep(random.uniform(1.5, 2.5))  # Randomized poll delay
                response_event = await http_request_with_retry("GET", event_url, stream=True)
                if response_event and response_event.status_code == 200:
                    event_stream_content = response_event.text  # Full SSE text
                    if 'data:' in event_stream_content:
                        # The final 'data:' section holds a Python-literal list.
                        final_data_str = event_stream_content.strip().split('data:')[-1].strip()
                        if final_data_str:
                            try:
                                parsed_list = ast.literal_eval(final_data_str)
                                if isinstance(parsed_list, list) and parsed_list:
                                    return parsed_list[0].get('label')
                                logger.warning(f"{checker_name} parsed empty or invalid list from event stream: {final_data_str[:100]}")
                            except (SyntaxError, ValueError, IndexError, TypeError) as e:
                                logger.warning(f"{checker_name} error parsing event stream: {e}, stream: {final_data_str[:100]}")
                elif response_event:
                    logger.warning(f"{checker_name} polling event_url returned status {response_event.status_code}")
                else:
                    logger.warning(f"{checker_name} polling event_url got no response.")

        else:  # Gradio queue-based Spaces: POST queue/join, then poll queue/data.
            if config.get("replica_code_needed"):
                replica_base_url = config.get("base_url")
                if not replica_base_url:
                    logger.error(f"{checker_name} needs replica_code but base_url is missing.")
                    return None
                code = await get_replica_code(replica_base_url)
                if not code:
                    logger.error(f"Failed to get replica code for {checker_name}")
                    return None
                # FIX: replica-based checkers only define "queue_join_url_template";
                # the old code unconditionally read config["queue_join_url"] before
                # this branch, which raised KeyError and made those checkers always
                # fail. The plain keys are now read only in the else-branch below.
                join_url = config["queue_join_url_template"].format(code=code)
                data_url = config["queue_data_url_template"].format(code=code, session_hash=session_hash)
            else:
                join_url = config["queue_join_url"]
                data_url = config["queue_data_url_template"].format(session_hash=session_hash)

            payload = config["payload_template"](img_url, session_hash)

            response_join = await http_request_with_retry("POST", join_url, json=payload)
            if not response_join or response_join.status_code != 200:
                logger.error(f"{checker_name} queue/join call failed. Status: {response_join.status_code if response_join else 'N/A'}")
                return None

            for _ in range(15):
                await asyncio.sleep(random.uniform(1.5, 2.5))  # Randomized poll delay
                response_data = await http_request_with_retry("GET", data_url, stream=True)
                if response_data and response_data.status_code == 200:
                    buffer = ""
                    for content_chunk in response_data.iter_content(chunk_size=1024, decode_unicode=True):
                        if content_chunk:
                            buffer += content_chunk
                            # FIX: the old completeness test was
                            # buffer.strip().endswith("}\n\n"), which can never be
                            # true (strip() removes the trailing newlines), so the
                            # buffer was never parsed. An SSE event block ends with
                            # a blank line, i.e. the buffer ends in "\n\n".
                            if buffer.endswith("\n\n"):
                                label = await parse_hf_queue_response(buffer)
                                if label:
                                    return label
                                buffer = ""  # Reset buffer after processing a block
                elif response_data:
                    logger.warning(f"{checker_name} polling queue/data returned status {response_data.status_code}")
                else:
                    logger.warning(f"{checker_name} polling queue/data got no response.")

        logger.warning(f"{checker_name} failed to get a conclusive result for {img_url}")
        return None

    except Exception as e:
        logger.error(f"Exception in {checker_name} for {img_url}: {e}", exc_info=True)
        return None
272
+
273
async def check_nsfw_final_concurrent(img_url: str) -> Optional[bool]:
    """Fan *img_url* out to every configured checker and combine the verdicts.

    Returns True as soon as any checker reports a label containing 'nsfw';
    False when no checker reported NSFW but at least one explicitly reported
    SFW/safe; None when every checker was inconclusive or failed.
    """
    logger.info(f"Starting NSFW check for: {img_url}")
    checker_names = [
        "checker2_jamescookjr90", "checker3_zanderlewis", "checker5_phelpsgg",
        "checker4_error466", "checker1_yoinked"
    ]

    # Tasks are keyed by checker name so completions can be logged usefully.
    # create_task starts them all immediately; we then consume in list order.
    named_tasks = {
        name: asyncio.create_task(check_nsfw_single_generic(name, img_url))
        for name in checker_names
    }

    sfw_found_by_any_checker = False

    for task_name in named_tasks:  # Iterate in defined order for potential preference
        try:
            label = await named_tasks[task_name]  # Wait for this specific task
            logger.info(f"Checker '{task_name}' result for {img_url}: {label}")
            if label:
                label_lower = label.lower()
                # 'nsfw' must be tested before 'sfw' ('nsfw' contains 'sfw').
                if 'nsfw' in label_lower:
                    logger.info(f"NSFW detected by '{task_name}' for {img_url}. Final: True.")
                    # FIX: cancel the remaining in-flight checkers. Previously they
                    # were left running after the early return, wasting remote
                    # polling and leaving tasks whose exceptions were never
                    # retrieved (asyncio logs those as warnings on teardown).
                    for other_name, other_task in named_tasks.items():
                        if other_name != task_name and not other_task.done():
                            other_task.cancel()
                    return True
                if 'sfw' in label_lower or 'safe' in label_lower:
                    # Remember the SFW vote but keep waiting: a later checker
                    # may still flag the image as NSFW.
                    sfw_found_by_any_checker = True
        except asyncio.CancelledError:
            logger.info(f"Checker '{task_name}' was cancelled for {img_url}.")
        except Exception as e:
            logger.error(f"Error processing result from checker '{task_name}' for {img_url}: {e}")

    if sfw_found_by_any_checker:  # No NSFW detected by any, but at least one said SFW
        logger.info(f"SFW confirmed for {img_url} (no NSFW detected, at least one SFW). Final: False.")
        return False

    logger.warning(f"All NSFW checkers inconclusive or failed for {img_url}. Final: None.")
    return None
316
+
317
+
318
# --- Video Processing Logic ---
# Extracted frames are written beneath this directory (one subdirectory per
# task) and served back over HTTP via the /static_frames mount, so the
# external checker Spaces can fetch each frame by public URL.
BASE_FRAMES_DIR = "/tmp/video_frames_service_advanced_v2"
os.makedirs(BASE_FRAMES_DIR, exist_ok=True)
app.mount("/static_frames", StaticFiles(directory=BASE_FRAMES_DIR), name="static_frames")
322
+
323
def extract_frames_sync(video_path, num_frames_to_extract, request_specific_frames_dir):
    """Synchronously sample frames from a video and save them as JPEGs.

    Samples are taken at evenly spaced positions across the clip and written
    under *request_specific_frames_dir* with unique names. Returns the list of
    file paths written successfully (empty when the video cannot be opened or
    has no frames). Blocking — intended to run in an executor.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        logger.error(f"Cannot open video file: {video_path}")
        return []

    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_total == 0:
        logger.warning(f"Video {video_path} has no frames.")
        capture.release()
        return []

    # Never request more samples than the clip holds; take at least one frame
    # when the clip has frames but the caller asked for zero.
    sample_count = min(num_frames_to_extract, frame_total)
    if sample_count == 0 and frame_total > 0:
        sample_count = 1
    if sample_count == 0:
        capture.release()
        return []

    written_paths = []
    for index in range(sample_count):
        # Spread the sample positions evenly, clamped to a valid frame index.
        position = int(index * frame_total / sample_count) if sample_count > 0 else 0
        position = min(position, frame_total - 1)

        capture.set(cv2.CAP_PROP_POS_FRAMES, position)
        ok, frame = capture.read()
        if not ok:
            # Keep going — a later seek position may still be readable.
            logger.warning(f"Failed to read frame at position {position} from {video_path}. Total frames: {frame_total}")
            continue
        out_path = os.path.join(request_specific_frames_dir, f"frame_{uuid.uuid4().hex}.jpg")
        if cv2.imwrite(out_path, frame):
            written_paths.append(out_path)
        else:
            logger.error(f"Failed to write frame: {out_path}")
    capture.release()
    return written_paths
365
+
366
async def process_video_core(task_id: str, video_path_on_disk: str, num_frames_to_analyze: int, app_base_url: str):
    """Background worker: extract frames, classify each, record the outcome.

    Progress and the final result are written into ``tasks_db[task_id]``
    (status transitions: processing -> completed/failed). The source video
    file is always deleted in the ``finally`` block. Frames are served from
    BASE_FRAMES_DIR via the /static_frames mount, so *app_base_url* must be
    the service's publicly reachable base URL for the remote checkers to
    fetch them.
    """
    tasks_db[task_id].update({"status": "processing", "message": "Extracting frames..."})

    # One frames subdirectory per task, matching the /static_frames/<task_id>/ URL path.
    request_frames_subdir = os.path.join(BASE_FRAMES_DIR, task_id)
    os.makedirs(request_frames_subdir, exist_ok=True)

    extracted_frames_disk_paths = []
    try:
        # Frame extraction is CPU/IO-bound and synchronous; run it off the event loop.
        loop = asyncio.get_event_loop()
        extracted_frames_disk_paths = await loop.run_in_executor(
            None, extract_frames_sync, video_path_on_disk, num_frames_to_analyze, request_frames_subdir
        )

        if not extracted_frames_disk_paths:
            tasks_db[task_id].update({"status": "failed", "message": "No frames could be extracted."})
            logger.error(f"Task {task_id}: No frames extracted from {video_path_on_disk}")
            # Clean up the video file if no frames extracted
            if os.path.exists(video_path_on_disk): os.remove(video_path_on_disk)
            return

        tasks_db[task_id].update({
            "status": "processing",
            "message": f"Analyzing {len(extracted_frames_disk_paths)} frames..."
        })

        nsfw_count = 0
        frame_results_list = []
        base_url_for_static_frames = f"{app_base_url.rstrip('/')}/static_frames/{task_id}"

        # Build one coroutine per frame; all frames are then checked concurrently.
        analysis_coroutines = []
        for frame_disk_path in extracted_frames_disk_paths:
            frame_filename_only = os.path.basename(frame_disk_path)
            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
            analysis_coroutines.append(check_nsfw_final_concurrent(img_http_url))

        # return_exceptions=True keeps gather's results index-aligned with the
        # frame list even when individual checks raise.
        nsfw_detection_results = await asyncio.gather(*analysis_coroutines, return_exceptions=True)

        for i, detection_result in enumerate(nsfw_detection_results):
            frame_disk_path = extracted_frames_disk_paths[i]
            frame_filename_only = os.path.basename(frame_disk_path)
            img_http_url = f"{base_url_for_static_frames}/{frame_filename_only}"
            # Per-frame verdict is reported as a string: "true"/"false" for a
            # definite answer, "unknown" when checkers were inconclusive (None),
            # "error" when the check itself raised.
            is_nsfw_str = "unknown"

            if isinstance(detection_result, Exception):
                logger.error(f"Task {task_id}: Error analyzing frame {img_http_url}: {detection_result}")
                is_nsfw_str = "error"
            else:  # detection_result is True, False, or None
                if detection_result is True:
                    nsfw_count += 1
                    is_nsfw_str = "true"
                elif detection_result is False:
                    is_nsfw_str = "false"

            frame_results_list.append({"frame_url": img_http_url, "nsfw_detected": is_nsfw_str})

        result_summary = {
            "nsfw_count": nsfw_count,
            "total_frames_analyzed": len(extracted_frames_disk_paths),
            "frames": frame_results_list
        }
        tasks_db[task_id].update({"status": "completed", "result": result_summary, "message": "Processing complete."})
        logger.info(f"Task {task_id}: Processing complete. Result: {result_summary}")

    except Exception as e:
        logger.error(f"Task {task_id}: Unhandled exception in process_video_core: {e}", exc_info=True)
        tasks_db[task_id].update({"status": "failed", "message": f"An internal error occurred: {str(e)}"})
    finally:
        # The source video is always removed; extracted frames are kept so the
        # result URLs stay valid.
        if os.path.exists(video_path_on_disk):
            try:
                os.remove(video_path_on_disk)
                logger.info(f"Task {task_id}: Cleaned up video file: {video_path_on_disk}")
            except OSError as e_remove:
                logger.error(f"Task {task_id}: Error cleaning up video file {video_path_on_disk}: {e_remove}")
        # Consider a separate job for cleaning up frame directories (request_frames_subdir) after a TTL
440
+
441
+
442
# --- API Request/Response Models ---
class VideoProcessRequest(BaseModel):
    # Body for POST /process_video_async.
    video_url: Optional[str] = Field(None, description="Publicly accessible URL of the video to process.")
    num_frames: int = Field(10, gt=0, le=50, description="Number of frames to extract (1-50). Max 50 for performance.")  # Reduced max
    app_base_url: str = Field(..., description="Public base URL of this API service (e.g., https://your-username-your-space-name.hf.space).")


class TaskCreationResponse(BaseModel):
    # Returned with HTTP 202 when a processing task has been accepted.
    task_id: str
    status_url: str  # Fully-qualified URL the client polls for progress/results.
    message: str


class FrameResult(BaseModel):
    # Per-frame verdict inside a completed task's result.
    frame_url: str
    nsfw_detected: str  # "true", "false", "unknown" (inconclusive), or "error".


class VideoProcessResult(BaseModel):
    # Aggregate result for one processed video.
    nsfw_count: int
    total_frames_analyzed: int
    frames: List[FrameResult]


class TaskStatusResponse(BaseModel):
    # Shape of GET /tasks/{task_id}/status; result is set only when completed.
    task_id: str
    status: str  # "pending", "processing", "completed", or "failed".
    message: Optional[str] = None
    result: Optional[VideoProcessResult] = None
467
+
468
+
469
# --- API Endpoints ---
@app.post("/process_video_async", response_model=TaskCreationResponse, status_code=202)
async def process_video_from_url_async_endpoint(
    request_data: VideoProcessRequest,  # Changed from 'request' to avoid conflict with FastAPI's Request object
    background_tasks: BackgroundTasks
):
    """Accept a video URL, download it, and queue background NSFW analysis.

    Returns 202 with a task id and a status URL to poll. The download itself
    still happens inside this handler (see NOTE below); only frame extraction
    and classification run in the background task.

    Raises 400 when video_url is missing or the download fails, 500 on
    unexpected submission errors.
    """
    if not request_data.video_url:
        raise HTTPException(status_code=400, detail="video_url must be provided.")

    task_id = str(uuid.uuid4())
    tasks_db[task_id] = {"status": "pending", "message": "Task received, preparing for download."}

    temp_video_file_path = None
    try:
        # NOTE(review): the download blocks this request until it completes. A
        # more robust design moves the download into the background task as
        # well (e.g. background_tasks.add_task(download_and_then_process, ...)).

        # Task-specific directory for the downloaded video.
        task_download_dir = os.path.join(BASE_FRAMES_DIR, "_video_downloads", task_id)
        os.makedirs(task_download_dir, exist_ok=True)
        # Basic suffix extraction from the URL path (query string stripped).
        video_suffix = os.path.splitext(request_data.video_url.split("?")[0])[-1] or ".mp4"
        if not video_suffix.startswith("."): video_suffix = "." + video_suffix

        temp_video_file_path = os.path.join(task_download_dir, f"downloaded_video{video_suffix}")

        logger.info(f"Task {task_id}: Attempting to download video from {request_data.video_url} to {temp_video_file_path}")

        dl_response = await http_request_with_retry("GET", request_data.video_url, stream=True)
        if not dl_response or dl_response.status_code != 200:
            if os.path.exists(task_download_dir): shutil.rmtree(task_download_dir)
            tasks_db[task_id].update({"status": "failed", "message": f"Failed to download video. Status: {dl_response.status_code if dl_response else 'N/A'}"})
            raise HTTPException(status_code=400, detail=f"Error downloading video: Status {dl_response.status_code if dl_response else 'N/A'}")

        with open(temp_video_file_path, "wb") as f:
            for chunk in dl_response.iter_content(chunk_size=8192*4):  # Increased chunk size
                f.write(chunk)
        logger.info(f"Task {task_id}: Video downloaded to {temp_video_file_path}")

        background_tasks.add_task(process_video_core, task_id, temp_video_file_path, request_data.num_frames, request_data.app_base_url)

        # Build the poll URL from the route plus the caller-supplied public base URL.
        status_url_path = app.url_path_for("get_task_status_endpoint", task_id=task_id)
        full_status_url = str(request_data.app_base_url.rstrip('/') + status_url_path)

        return TaskCreationResponse(
            task_id=task_id,
            status_url=full_status_url,
            message="Video processing task accepted and started in background."
        )

    except HTTPException:
        # FIX: let deliberate HTTP errors (the 400 raised above on download
        # failure) propagate unchanged. Previously the generic `except
        # Exception` below caught them and converted the response into a
        # misleading 500 "Internal server error during task submission."
        raise
    except requests.exceptions.RequestException as e:
        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path))
        tasks_db[task_id].update({"status": "failed", "message": f"Video download error: {e}"})
        raise HTTPException(status_code=400, detail=f"Error downloading video: {e}")
    except Exception as e:
        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path))
        logger.error(f"Task {task_id}: Unexpected error during task submission: {e}", exc_info=True)
        tasks_db[task_id].update({"status": "failed", "message": "Internal server error during task submission."})
        raise HTTPException(status_code=500, detail="Internal server error during task submission.")
531
+
532
+
533
@app.post("/upload_video_async", response_model=TaskCreationResponse, status_code=202)
async def upload_video_async_endpoint(
    background_tasks: BackgroundTasks,
    app_base_url: str = Form(..., description="Public base URL of this API service."),
    num_frames: int = Form(10, gt=0, le=50, description="Number of frames to extract (1-50)."),
    video_file: UploadFile = File(..., description="Video file to upload and process.")
):
    """Accept a direct video upload and queue background NSFW analysis.

    Saves the upload under a task-specific directory, schedules
    process_video_core, and returns 202 with a task id plus the status URL to
    poll. Raises 400 for a non-video content type, 500 on save/submit errors.
    """
    # Cheap first-line validation; relies on the client-declared content type.
    if not video_file.content_type or not video_file.content_type.startswith("video/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")

    task_id = str(uuid.uuid4())
    tasks_db[task_id] = {"status": "pending", "message": "Task received, saving uploaded video."}
    temp_video_file_path = None
    try:
        upload_dir = os.path.join(BASE_FRAMES_DIR, "_video_uploads", task_id)  # Task-specific upload dir
        os.makedirs(upload_dir, exist_ok=True)

        # Keep the original extension when present; default to .mp4 otherwise.
        suffix = os.path.splitext(video_file.filename)[1] if video_file.filename and "." in video_file.filename else ".mp4"
        if not suffix.startswith("."): suffix = "." + suffix

        temp_video_file_path = os.path.join(upload_dir, f"uploaded_video{suffix}")

        # Stream the upload to disk without reading it fully into memory.
        with open(temp_video_file_path, "wb") as buffer:
            shutil.copyfileobj(video_file.file, buffer)
        logger.info(f"Task {task_id}: Video uploaded and saved to {temp_video_file_path}")

        background_tasks.add_task(process_video_core, task_id, temp_video_file_path, num_frames, app_base_url)

        # Build the poll URL from the route plus the caller-supplied public base URL.
        status_url_path = app.url_path_for("get_task_status_endpoint", task_id=task_id)
        full_status_url = str(app_base_url.rstrip('/') + status_url_path)
        return TaskCreationResponse(
            task_id=task_id,
            status_url=full_status_url,
            message="Video upload accepted and processing started in background."
        )
    except Exception as e:
        # Remove the task's upload directory on failure so partial files don't linger.
        if temp_video_file_path and os.path.exists(os.path.dirname(temp_video_file_path)): shutil.rmtree(os.path.dirname(temp_video_file_path))
        logger.error(f"Task {task_id}: Error handling video upload: {e}", exc_info=True)
        tasks_db[task_id].update({"status": "failed", "message": "Internal server error during video upload."})
        raise HTTPException(status_code=500, detail=f"Error processing uploaded file: {e}")
    finally:
        if video_file:
            await video_file.close()
576
+
577
+
578
@app.get("/tasks/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status_endpoint(task_id: str):
    """Return the status record for a processing task, or 404 if unknown."""
    record = tasks_db.get(task_id)
    if not record:
        raise HTTPException(status_code=404, detail="Task not found.")
    return TaskStatusResponse(task_id=task_id, **record)
584
+
585
# --- Homepage Endpoint ---
def _guess_example_base_url(fastapi_request: Request) -> str:
    """Best-effort guess of this service's public base URL for the docs examples.

    Preference order:
      1. ``SPACE_ID`` environment variable, if set.
      2. A ``*.hf.space`` Host header (typical when deployed as a Hugging
         Face Space and no env hint is available).
      3. The request URL itself when served from localhost.

    NOTE(review): behind other reverse proxies the Host header may not be the
    public hostname — the rendered page already asks users to verify the URL.
    """
    placeholder = "your-username-your-space-name"
    hf_space_name = os.getenv("SPACE_ID", placeholder)
    if hf_space_name == placeholder:
        # No env hint: fall back to the Host header when it looks like a
        # Hugging Face Space domain. (Single lookup; the original fetched
        # the header twice and reused the name 'host' for two things.)
        header_host = fastapi_request.headers.get("host")
        if header_host and ".hf.space" in header_host:
            hf_space_name = header_host

    url_host = fastapi_request.url.hostname
    if url_host == "localhost" or url_host == "127.0.0.1":
        # Local development: reconstruct scheme://host[:port] from the request.
        scheme = fastapi_request.url.scheme
        port = fastapi_request.url.port
        return f"{scheme}://{url_host}:{port}" if port else f"{scheme}://{url_host}"

    # Assume deployed (e.g. on Hugging Face): append the Space domain suffix
    # unless the name already carries it.
    if ".hf.space" in hf_space_name:
        return f"https://{hf_space_name}"
    return f"https://{hf_space_name}.hf.space"


@app.get("/", response_class=HTMLResponse)
async def read_root(fastapi_request: Request):  # named 'fastapi_request' to avoid clashing with other 'request' names
    """Serve the HTML landing page that documents the API's endpoints.

    The page embeds a guessed public base URL (see ``_guess_example_base_url``)
    into the example ``curl`` commands so users can copy-paste them.
    """
    example_app_base_url = _guess_example_base_url(fastapi_request)

    # Note: literal braces in the CSS/JSON below are escaped as {{ }} because
    # this is an f-string; only {example_app_base_url} and {app.version} are
    # interpolated.
    html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>NSFW Video Detector API</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; background-color: #f4f4f4; color: #333; }}
        .container {{ background-color: #fff; padding: 20px; border-radius: 8px; box-shadow: 0 0 10px rgba(0,0,0,0.1); }}
        h1, h2, h3 {{ color: #333; }}
        h1 {{ text-align: center; border-bottom: 2px solid #eee; padding-bottom: 10px;}}
        h2 {{ border-bottom: 1px solid #eee; padding-bottom: 5px; margin-top: 30px;}}
        code {{ background-color: #eef; padding: 2px 6px; border-radius: 4px; font-family: "Courier New", Courier, monospace;}}
        pre {{ background-color: #eef; padding: 15px; border-radius: 4px; overflow-x: auto; border: 1px solid #ddd; }}
        .endpoint {{ margin-bottom: 20px; }}
        .param {{ font-weight: bold; }}
        .note {{ background-color: #fff9c4; border-left: 4px solid #fdd835; padding: 10px; margin: 15px 0; border-radius:4px; }}
        .tip {{ background-color: #e8f5e9; border-left: 4px solid #4caf50; padding: 10px; margin: 15px 0; border-radius:4px; }}
        table {{ width: 100%; border-collapse: collapse; margin-top:10px; }}
        th, td {{ text-align: left; padding: 8px; border-bottom: 1px solid #ddd; }}
        th {{ background-color: #f0f0f0; }}
        a {{ color: #007bff; text-decoration: none; }}
        a:hover {{ text-decoration: underline; }}
    </style>
</head>
<body>
    <div class="container">
        <h1>NSFW Video Detector API</h1>
        <p>This API allows you to process videos to detect Not Suitable For Work (NSFW) content. It works asynchronously: you submit a video (via URL or direct upload), receive a task ID, and then poll a status endpoint to get the results.</p>

        <div class="note">
            <p><span class="param">Important:</span> The <code>app_base_url</code> parameter is crucial. It must be the public base URL where this API service is accessible. For example, if your Hugging Face Space URL is <code>https://your-username-your-space-name.hf.space</code>, then that's your <code>app_base_url</code>.</p>
            <p>Current detected example base URL for instructions: <code>{example_app_base_url}</code> (This is a guess, please verify your actual public URL).</p>
        </div>

        <h2>Endpoints</h2>

        <div class="endpoint">
            <h3>1. Process Video from URL (Asynchronous)</h3>
            <p><code>POST /process_video_async</code></p>
            <p>Submits a video from a public URL for NSFW analysis.</p>
            <h4>Request Body (JSON):</h4>
            <table>
                <tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr>
                <tr><td><span class="param">video_url</span></td><td>string</td><td><em>Required</em></td><td>Publicly accessible URL of the video.</td></tr>
                <tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract and analyze (1-50).</td></tr>
                <tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr>
            </table>
            <h4>Example using <code>curl</code>:</h4>
            <pre><code>curl -X POST "{example_app_base_url}/process_video_async" \\
     -H "Content-Type: application/json" \\
     -d '{{
           "video_url": "YOUR_PUBLIC_VIDEO_URL_HERE.mp4",
           "num_frames": 5,
           "app_base_url": "{example_app_base_url}"
         }}'</code></pre>
        </div>

        <div class="endpoint">
            <h3>2. Upload Video File (Asynchronous)</h3>
            <p><code>POST /upload_video_async</code></p>
            <p>Uploads a video file directly for NSFW analysis.</p>
            <h4>Request Body (Multipart Form-Data):</h4>
            <table>
                <tr><th>Parameter</th><th>Type</th><th>Default</th><th>Description</th></tr>
                <tr><td><span class="param">video_file</span></td><td>file</td><td><em>Required</em></td><td>The video file to upload.</td></tr>
                <tr><td><span class="param">num_frames</span></td><td>integer</td><td>10</td><td>Number of frames to extract (1-50).</td></tr>
                <tr><td><span class="param">app_base_url</span></td><td>string</td><td><em>Required</em></td><td>The public base URL of this API service.</td></tr>
            </table>
            <h4>Example using <code>curl</code>:</h4>
            <pre><code>curl -X POST "{example_app_base_url}/upload_video_async" \\
     -F "video_file=@/path/to/your/video.mp4" \\
     -F "num_frames=5" \\
     -F "app_base_url={example_app_base_url}"</code></pre>
        </div>

        <div class="tip">
            <h4>Response for Task Creation (for both URL and Upload):</h4>
            <p>If successful (HTTP 202 Accepted), the API will respond with:</p>
            <pre><code>{{
  "task_id": "some-unique-task-id",
  "status_url": "{example_app_base_url}/tasks/some-unique-task-id/status",
  "message": "Video processing task accepted and started in background."
}}</code></pre>
        </div>

        <div class="endpoint">
            <h3>3. Get Task Status and Result</h3>
            <p><code>GET /tasks/&lt;task_id&gt;/status</code></p>
            <p>Poll this endpoint to check the status of a processing task and retrieve the result once completed.</p>
            <h4>Example using <code>curl</code>:</h4>
            <pre><code>curl -X GET "{example_app_base_url}/tasks/some-unique-task-id/status"</code></pre>
            <h4>Possible Statuses:</h4>
            <ul>
                <li><code>pending</code>: Task is queued.</li>
                <li><code>processing</code>: Task is actively being processed (downloading, extracting frames, analyzing).</li>
                <li><code>completed</code>: Task finished successfully. Results are available in the <code>result</code> field.</li>
                <li><code>failed</code>: Task failed. Check the <code>message</code> field for details.</li>
            </ul>
            <h4>Example Response (Status: <code>completed</code>):</h4>
            <pre><code>{{
  "task_id": "some-unique-task-id",
  "status": "completed",
  "message": "Processing complete.",
  "result": {{
    "nsfw_count": 1,
    "total_frames_analyzed": 5,
    "frames": [
      {{
        "frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid1.jpg",
        "nsfw_detected": "false"
      }},
      {{
        "frame_url": "{example_app_base_url}/static_frames/some-unique-task-id/frame_uuid2.jpg",
        "nsfw_detected": "true"
      }}
      // ... more frames
    ]
  }}
}}</code></pre>
            <h4>Example Response (Status: <code>processing</code>):</h4>
            <pre><code>{{
  "task_id": "some-unique-task-id",
  "status": "processing",
  "message": "Analyzing 5 frames...",
  "result": null
}}</code></pre>
        </div>
        <p style="text-align:center; margin-top:30px; font-size:0.9em; color:#777;">API Version: {app.version}</p>
    </div>
</body>
</html>
"""
    # A Jinja2 'templates/index.html' could replace this inline page; the
    # f-string fallback keeps the app self-contained.
    return HTMLResponse(content=html_content, status_code=200)
+
754
+
755
# Example of how to run for local development:
# 1. Ensure you have a 'templates/index.html' file or the fallback HTML will be used.
# 2. Run: uvicorn app:app --reload --host 0.0.0.0 --port 8000
#    (this file is app.py; adjust the module name if you rename it)
# Requirements: fastapi uvicorn[standard] opencv-python requests pydantic python-multipart (for Form/File uploads)