import gradio as gr
import cv2  # OpenCV for video processing
import os  # For interacting with the file system
import requests  # For making HTTP requests
import random
import string
import json
import shutil  # For file operations
import ast  # For safely evaluating string literals
import uuid  # For generating unique filenames
import asyncio  # For concurrent operations
import time  # For retries and delays
import logging  # For structured logging
from typing import Optional  # For the Optional[...] type hints used below
# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
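# Example log line produced by this format (illustrative):
#   2024-01-01 12:00:00,000 - __main__ - INFO - Starting NSFW check for: https://...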
# --- Configuration ---
# Directory for temporarily storing extracted frames.
GRADIO_TEMP_FRAME_DIR = "/tmp/gradio_nsfw_frames_advanced"
os.makedirs(GRADIO_TEMP_FRAME_DIR, exist_ok=True)
# The public URL of this Gradio Space. Crucial for external NSFW checkers.
# Set via environment variable or update the placeholder if hardcoding.
APP_BASE_URL = os.getenv("APP_BASE_URL", "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE")
DEFAULT_REQUEST_TIMEOUT = 20  # seconds
MAX_RETRY_ATTEMPTS = 3
RETRY_BACKOFF_FACTOR = 2
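# With RETRY_BACKOFF_FACTOR = 2, the exponential backoff in http_request_with_retry
# below works out to delays of roughly 1s then 2s (each plus up to 0.5s of jitter)
# between the three attempts.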
# --- NSFW Checker Configuration (from FastAPI version) ---
NSFW_CHECKER_CONFIG = {
    "checker1_yoinked": {
        "queue_join_url": "https://yoinked-da-nsfw-checker.hf.space/queue/join",
        "queue_data_url_template": "https://yoinked-da-nsfw-checker.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}, "chen-convnext", 0.5, True, True],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 12
        }
    },
    "checker2_jamescookjr90": {
        "queue_join_url": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    },
    "checker3_zanderlewis": {
        "predict_url": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict",
        "event_url_template": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict/{event_id}",
        "payload_template": lambda img_url: {'data': [{'path': img_url}]}
    },
    "checker4_error466": {
        "base_url": "https://error466-falconsai-nsfw-image-detection.hf.space",
        "replica_code_needed": True,
        "queue_join_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/join",
        "queue_data_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 58
        }
    },
    "checker5_phelpsgg": {  # Original user code had 'imseldrith'; the FastAPI version used 'phelpsgg'. Keeping phelpsgg.
        "queue_join_url": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/join",
        "queue_data_url_template": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
        "payload_template": lambda img_url, session_hash: {
            'data': [{'path': img_url}],
            'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
        }
    }
}
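# For illustration, checker2's payload_template renders for a given frame URL and
# session hash as (values hypothetical):
#   {'data': [{'path': 'https://my-space.hf.space/file=/tmp/gradio_nsfw_frames_advanced/ab12cd.jpg'}],
#    'session_hash': 'k3x9w7q2ab', 'fn_index': 0, 'trigger_id': 9}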
# --- Helper Functions (from FastAPI version, adapted for Gradio context) ---

# Module-level semaphore so the concurrency cap is shared across all requests.
# (A Semaphore created inside the function would be new per call and limit nothing.)
HTTP_SEMAPHORE = asyncio.Semaphore(10)

async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
    headers = kwargs.pop("headers", {})
    headers.setdefault("User-Agent", "GradioNSFWClient/1.0")
    for attempt in range(MAX_RETRY_ATTEMPTS):
        try:
            async with HTTP_SEMAPHORE:
                loop = asyncio.get_running_loop()
                response = await loop.run_in_executor(
                    None,
                    lambda: requests.request(method, url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT, **kwargs)
                )
            response.raise_for_status()
            return response
        except requests.exceptions.Timeout:
            logger.warning(f"Request timeout for {url} on attempt {attempt + 1}")
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code in (429, 502, 503, 504):
                logger.warning(f"HTTP error {e.response.status_code} for {url} on attempt {attempt + 1}")
            else:
                logger.error(f"Non-retriable HTTP error for {url}: {e}")
                return e.response if e.response is not None else None
        except requests.exceptions.RequestException as e:
            logger.error(f"Request exception for {url} on attempt {attempt + 1}: {e}")
        if attempt < MAX_RETRY_ATTEMPTS - 1:
            delay = (RETRY_BACKOFF_FACTOR ** attempt) + random.uniform(0, 0.5)
            logger.info(f"Retrying {url} in {delay:.2f} seconds...")
            await asyncio.sleep(delay)
    logger.error(f"All {MAX_RETRY_ATTEMPTS} retry attempts failed for {url}.")
    return None
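# Illustrative usage (hypothetical URL):
#   resp = await http_request_with_retry("GET", "https://example.hf.space/health")
#   if resp is not None and resp.status_code == 200:
#       ...  # resp is a regular requests.Response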
def get_replica_code_sync(url: str) -> Optional[str]:
    try:
        r = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers={"User-Agent": "GradioNSFWClient/1.0"})
        r.raise_for_status()
        parts = r.text.split('replicas/')
        if len(parts) > 1:
            return parts[1].split('"};')[0]
        logger.warning(f"Could not find 'replicas/' in content from {url}")
        return None
    except (requests.exceptions.RequestException, IndexError, KeyError) as e:
        logger.error(f"Error getting replica code for {url}: {e}")
        return None

async def get_replica_code(url: str) -> Optional[str]:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, get_replica_code_sync, url)
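# The scrape above assumes the Space's HTML embeds a replica URL fragment along
# the lines of (hypothetical): ...{"root": "https://...--replicas/abc123"};...
# so splitting on 'replicas/' and then on '"};' recovers the code 'abc123'.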
async def parse_hf_queue_response(response_content: str) -> Optional[str]:
    try:
        messages = response_content.strip().split('\n')
        for msg_str in reversed(messages):
            if msg_str.startswith("data:"):
                try:
                    data_json_str = msg_str[len("data:"):].strip()
                    if not data_json_str:
                        continue
                    parsed_json = json.loads(data_json_str)
                    if parsed_json.get("msg") == "process_completed":
                        output_data = parsed_json.get("output", {}).get("data")
                        if output_data and isinstance(output_data, list) and len(output_data) > 0:
                            first_item = output_data[0]
                            if isinstance(first_item, dict):
                                return first_item.get('label')
                            if isinstance(first_item, str):
                                return first_item
                        logger.warning(f"Unexpected 'process_completed' data structure: {output_data}")
                        return None
                except json.JSONDecodeError:
                    logger.debug(f"Failed to decode JSON from part of HF stream: {data_json_str[:100]}")
                    continue
        return None
    except Exception as e:
        logger.error(f"Error parsing HF queue response: {e}, content: {response_content[:200]}")
        return None
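# The queue/data endpoint emits server-sent events; a completed job looks roughly
# like this (shape inferred from the parser above, field values hypothetical):
#   data: {"msg": "process_starts", ...}
#   data: {"msg": "process_completed", "output": {"data": [{"label": "nsfw", ...}]}, ...}
# The parser walks the messages in reverse and returns the label from the first
# "process_completed" message it finds.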
async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
    config = NSFW_CHECKER_CONFIG.get(checker_name)
    if not config:
        logger.error(f"No configuration found for checker: {checker_name}")
        return None
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(10))
    try:
        if "predict_url" in config:  # ZanderLewis-style /call/predict + event polling
            payload = config["payload_template"](img_url)
            response_predict = await http_request_with_retry("POST", config["predict_url"], json=payload)
            if not response_predict or response_predict.status_code != 200:
                logger.error(f"{checker_name} predict call failed. Status: {response_predict.status_code if response_predict else 'N/A'}")
                return None
            json_data = response_predict.json()
            event_id = json_data.get('event_id')
            if not event_id:
                logger.error(f"{checker_name} did not return an event_id.")
                return None
            event_url = config["event_url_template"].format(event_id=event_id)
            for _ in range(10):
                await asyncio.sleep(random.uniform(1.5, 2.5))
                # No stream=True here: the whole event stream body is read via .text below.
                response_event = await http_request_with_retry("GET", event_url)
                if response_event and response_event.status_code == 200:
                    event_stream_content = response_event.text
                    if 'data:' in event_stream_content:
                        final_data_str = event_stream_content.strip().split('data:')[-1].strip()
                        if final_data_str:
                            try:
                                parsed_list = ast.literal_eval(final_data_str)
                                if isinstance(parsed_list, list) and parsed_list and isinstance(parsed_list[0], dict):
                                    return parsed_list[0].get('label')
                                logger.warning(f"{checker_name} parsed a non-list or empty list from event stream: {final_data_str[:100]}")
                            except (SyntaxError, ValueError, IndexError, TypeError) as e:
                                logger.warning(f"{checker_name} error parsing event stream: {e}, stream: {final_data_str[:100]}")
                elif response_event:
                    logger.warning(f"{checker_name} polling event_url returned status {response_event.status_code}")
                else:
                    logger.warning(f"{checker_name} polling event_url got no response.")
        else:  # Queue-based APIs (queue/join, then poll queue/data)
            # Use .get() here: replica-based checkers carry a *_template key instead,
            # and join_url is overwritten below once the replica code is known.
            join_url = config.get("queue_join_url")
            data_url_template = config["queue_data_url_template"]
            if config.get("replica_code_needed"):
                replica_base_url = config.get("base_url")
                if not replica_base_url:
                    logger.error(f"{checker_name} needs a replica code but base_url is missing.")
                    return None
                code = await get_replica_code(replica_base_url)
                if not code:
                    logger.error(f"Failed to get replica code for {checker_name}")
                    return None
                join_url = config["queue_join_url_template"].format(code=code)
                data_url = data_url_template.format(code=code, session_hash=session_hash)
            else:
                data_url = data_url_template.format(session_hash=session_hash)
            payload = config["payload_template"](img_url, session_hash)
            response_join = await http_request_with_retry("POST", join_url, json=payload)
            if not response_join or response_join.status_code != 200:
                logger.error(f"{checker_name} queue/join call failed. Status: {response_join.status_code if response_join else 'N/A'}")
                return None
            for _ in range(15):
                await asyncio.sleep(random.uniform(1.5, 2.5))
                response_data = await http_request_with_retry("GET", data_url, stream=True)
                if response_data and response_data.status_code == 200:
                    buffer = ""
                    # Note: only the initial request runs in the executor; iter_content
                    # itself is synchronous and blocks the event loop while streaming.
                    for content_chunk in response_data.iter_content(chunk_size=1024, decode_unicode=True):
                        if content_chunk:
                            buffer += content_chunk
                            # SSE messages are delimited by a blank line, so attempt a
                            # parse once the buffer ends with one.
                            if buffer.endswith("\n\n"):
                                label = await parse_hf_queue_response(buffer)
                                if label:
                                    return label
                                buffer = ""
                elif response_data:
                    logger.warning(f"{checker_name} polling queue/data returned status {response_data.status_code}")
                else:
                    logger.warning(f"{checker_name} polling queue/data got no response.")
        logger.warning(f"{checker_name} failed to get a conclusive result for {img_url}")
        return None
    except Exception as e:
        logger.error(f"Exception in {checker_name} for {img_url}: {e}", exc_info=True)
        return None
async def check_nsfw_final_concurrent(img_url: str) -> Optional[bool]:
    logger.info(f"Starting NSFW check for: {img_url}")
    # Prioritized list from the FastAPI version.
    checker_names = [
        "checker2_jamescookjr90", "checker3_zanderlewis", "checker5_phelpsgg",
        "checker4_error466", "checker1_yoinked"
    ]
    named_tasks = {
        name: asyncio.create_task(check_nsfw_single_generic(name, img_url))
        for name in checker_names
    }
    sfw_found_by_any_checker = False
    # All tasks run concurrently, but their results are awaited one by one in
    # checker_names order. For a true "first result wins" policy,
    # asyncio.as_completed is better suited; see the sketch after this function.
    for task_name in checker_names:
        try:
            label = await named_tasks[task_name]
            logger.info(f"Checker '{task_name}' result for {img_url}: {label}")
            if label:
                label_lower = label.lower()
                if 'nsfw' in label_lower:
                    logger.info(f"NSFW detected by '{task_name}' for {img_url}. Final: True.")
                    # Cancel the remaining tasks; their results no longer matter.
                    for t_name_to_cancel, t_obj_to_cancel in named_tasks.items():
                        if t_name_to_cancel != task_name and not t_obj_to_cancel.done():
                            t_obj_to_cancel.cancel()
                    return True
                if 'sfw' in label_lower or 'safe' in label_lower:
                    sfw_found_by_any_checker = True
        except asyncio.CancelledError:
            logger.info(f"Checker '{task_name}' was cancelled for {img_url}.")
        except Exception as e:
            logger.error(f"Error processing result from checker '{task_name}' for {img_url}: {e}")
    if sfw_found_by_any_checker:
        logger.info(f"SFW confirmed for {img_url} (no NSFW detected by any checker, at least one SFW). Final: False.")
        return False
    logger.warning(f"All NSFW checkers inconclusive or failed for {img_url}. Final: None.")
    return None
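# A minimal sketch of the "first result wins" variant mentioned above, using
# asyncio.as_completed so the earliest definitive verdict short-circuits the rest.
# Illustrative only (the app uses check_nsfw_final_concurrent above); the helper
# name check_nsfw_first_wins is hypothetical.
async def check_nsfw_first_wins(img_url: str) -> Optional[bool]:
    tasks = [asyncio.create_task(check_nsfw_single_generic(name, img_url))
             for name in NSFW_CHECKER_CONFIG]
    sfw_seen = False
    try:
        # as_completed yields results in finish order, not submission order.
        for finished in asyncio.as_completed(tasks):
            label = await finished
            if label and 'nsfw' in label.lower():
                return True  # First NSFW verdict wins; the finally block cancels the rest.
            if label and ('sfw' in label.lower() or 'safe' in label.lower()):
                sfw_seen = True
    finally:
        for t in tasks:
            if not t.done():
                t.cancel()
    return False if sfw_seen else None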
# --- Gradio Specific Functions ---
def extract_frames_sync(video_path: str, num_frames_to_extract: int, progress: gr.Progress = None) -> list:
    if progress: progress(0, desc="Starting frame extraction...")
    vidcap = cv2.VideoCapture(video_path)
    if not vidcap.isOpened():
        logger.error(f"Error: Cannot open video file {video_path}")
        return []
    total_frames_in_video = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    extracted_filenames = []  # Stores only filenames, not full paths
    if total_frames_in_video == 0:
        vidcap.release()
        return []
    actual_frames_to_extract = min(num_frames_to_extract, total_frames_in_video)
    if actual_frames_to_extract == 0 and total_frames_in_video > 0:
        actual_frames_to_extract = 1
    if actual_frames_to_extract == 0:
        vidcap.release()
        return []
    for i in range(actual_frames_to_extract):
        if progress: progress(i / actual_frames_to_extract, desc=f"Extracting frame {i+1}/{actual_frames_to_extract}")
        frame_number = int(i * total_frames_in_video / actual_frames_to_extract)
        frame_number = min(frame_number, total_frames_in_video - 1)
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        success, image = vidcap.read()
        if success:
            random_filename = f"{uuid.uuid4().hex}.jpg"
            full_frame_path = os.path.join(GRADIO_TEMP_FRAME_DIR, random_filename)
            try:
                cv2.imwrite(full_frame_path, image)
                extracted_filenames.append(random_filename)
            except Exception as e:
                logger.error(f"Error writing frame {full_frame_path}: {e}")
        else:
            logger.warning(f"Failed to read frame at position {frame_number} from {video_path}")
    vidcap.release()
    if progress: progress(1, desc="Frame extraction complete.")
    return extracted_filenames
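# Sampling above is evenly spaced across the video: e.g. for a 300-frame video
# with num_frames_to_extract=5, frame_number takes the values 0, 60, 120, 180, 240.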
async def process_video_gradio(video_temp_path: str, num_frames: int, progress=gr.Progress(track_tqdm=True)):
    """
    Main async processing function for Gradio, using the robust NSFW checkers.
    """
    if APP_BASE_URL == "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE":
        warning_msg = ("CRITICAL WARNING: APP_BASE_URL is not set! External NSFW checks will likely fail. "
                       "Please set the APP_BASE_URL environment variable to your public Gradio Space URL.")
        logger.error(warning_msg)
        # Surfaced in the JSON output; could also be shown in the Gradio UI.
        return {"error": warning_msg, "details": "The application needs to know its own public URL to construct frame URLs for external analysis services."}
    if not video_temp_path:
        return {"error": "No video file provided or video path is invalid."}
    try:
        num_frames = int(num_frames)
        if num_frames <= 0:
            return {"error": "Number of frames must be a positive integer."}
    except (ValueError, TypeError):
        return {"error": "Invalid number for frames."}
    # Run synchronous frame extraction in an executor to keep the event loop unblocked.
    loop = asyncio.get_running_loop()
    extracted_frame_filenames = await loop.run_in_executor(
        None, extract_frames_sync, video_temp_path, num_frames, progress
    )
    if not extracted_frame_filenames:
        # Clean up the uploaded video file if extraction failed.
        if os.path.exists(video_temp_path):
            try:
                os.remove(video_temp_path)
            except Exception as e:
                logger.error(f"Error cleaning up input video {video_temp_path} after failed extraction: {e}")
        return {"error": "Could not extract any frames from the video."}
    nsfw_count = 0
    total_frames_processed = len(extracted_frame_filenames)
    frame_results_output = []
    analysis_coroutines = []
    for frame_filename in extracted_frame_filenames:
        # Construct the URL for the NSFW checker using Gradio's /file= route.
        # The path after /file= must be the absolute server path where Gradio can find the file.
        absolute_frame_path_on_server = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
        publicly_accessible_frame_url = f"{APP_BASE_URL.rstrip('/')}/file={absolute_frame_path_on_server}"
        analysis_coroutines.append(check_nsfw_final_concurrent(publicly_accessible_frame_url))
    # gr.Progress doesn't map directly onto asyncio.gather, so set a general
    # message for the analysis phase.
    if progress: progress(0.5, desc=f"Analyzing {total_frames_processed} frames (may take time)...")
    nsfw_detection_results = await asyncio.gather(*analysis_coroutines, return_exceptions=True)
    if progress: progress(0.9, desc="Compiling results...")
    for i, detection_result in enumerate(nsfw_detection_results):
        frame_filename = extracted_frame_filenames[i]
        absolute_frame_path_on_server = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
        publicly_accessible_frame_url = f"{APP_BASE_URL.rstrip('/')}/file={absolute_frame_path_on_server}"
        is_nsfw_str = "unknown"
        if isinstance(detection_result, Exception):
            logger.error(f"Error analyzing frame {publicly_accessible_frame_url}: {detection_result}")
            is_nsfw_str = "error"
        elif detection_result is True:
            nsfw_count += 1
            is_nsfw_str = "true"
        elif detection_result is False:
            is_nsfw_str = "false"
        # detection_result of None keeps the "unknown" default.
        frame_results_output.append({
            "frame_filename_on_server": frame_filename,
            "checked_url": publicly_accessible_frame_url,
            "nsfw_detected": is_nsfw_str
        })
    if progress: progress(1, desc="Analysis complete. Cleaning up temporary files...")
    # Clean up extracted frames.
    for frame_filename in extracted_frame_filenames:
        full_frame_path_to_delete = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
        if os.path.exists(full_frame_path_to_delete):
            try:
                os.remove(full_frame_path_to_delete)
            except Exception as e:
                logger.error(f"Error deleting frame {full_frame_path_to_delete}: {e}")
    # Gradio manages cleanup of `video_temp_path` (the uploaded video).
    final_result_json = {
        "summary": {
            "nsfw_frames_found": nsfw_count,
            "total_frames_analyzed": total_frames_processed,
            "app_base_url_used_for_checks": APP_BASE_URL,
            "frames_temp_dir_on_server": GRADIO_TEMP_FRAME_DIR
        },
        "frame_details": frame_results_output
    }
    return final_result_json
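# Example of the JSON returned on success (values hypothetical):
# {
#   "summary": {"nsfw_frames_found": 1, "total_frames_analyzed": 10,
#               "app_base_url_used_for_checks": "https://my-space.hf.space",
#               "frames_temp_dir_on_server": "/tmp/gradio_nsfw_frames_advanced"},
#   "frame_details": [{"frame_filename_on_server": "ab12cd34.jpg",
#                      "checked_url": "https://my-space.hf.space/file=/tmp/gradio_nsfw_frames_advanced/ab12cd34.jpg",
#                      "nsfw_detected": "false"}, ...]
# }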
# --- Gradio Interface Definition ---
with gr.Blocks(css="footer {display: none !important;}", title="NSFW Video Detector") as app_interface:
    gr.Markdown(
        """
        # NSFW Frame Detection in Video
        Upload a video and specify the number of frames to check for NSFW content.
        The analysis uses a series of external NSFW detection models with retries and concurrent checks for robustness.
        """
    )
    with gr.Accordion("Important: How this Space Works & Configuration", open=False):
        gr.Markdown(
            f"""
            - **`APP_BASE_URL`**: For this Space to work correctly when deployed (e.g., on Hugging Face Spaces), the `APP_BASE_URL` environment variable **must** be set to its public URL (e.g., `https://your-username-your-spacename.hf.space`), because external NSFW checkers need to access the extracted video frames via public URLs.
            - **Currently configured `APP_BASE_URL` for frame checking: `{APP_BASE_URL}`**. If this shows the placeholder, the NSFW checks will likely fail.
            - **Temporary Frame Storage**: Frames are temporarily extracted to `{GRADIO_TEMP_FRAME_DIR}` on the server and deleted after processing.
            - **Processing Time**: Depends on video length, the number of frames selected, and the responsiveness of the external NSFW checking services. Please be patient.
            """
        )
    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video")
            num_frames_input = gr.Number(label="Number of Frames to Check (e.g., 5-20)", value=10, minimum=1, maximum=50, step=1, precision=0)
            submit_button = gr.Button("Detect NSFW Content", variant="primary")
        with gr.Column(scale=2):
            output_json = gr.JSON(label="Detection Result")
    submit_button.click(
        fn=process_video_gradio,  # Async handler; Gradio awaits coroutines natively
        inputs=[video_input, num_frames_input],
        outputs=output_json
    )
    gr.Examples(
        examples=[
            # Add a sample video path here if one is bundled with the Space repo,
            # e.g. [os.path.join(os.path.dirname(__file__), "sample_video.mp4"), 5],
        ],
        inputs=[video_input, num_frames_input],
        outputs=output_json,
        fn=process_video_gradio,  # Ensure examples use the async handler too
        cache_examples=False
    )
    gr.Markdown("Note: If `APP_BASE_URL` is not correctly set to this Space's public URL, NSFW detection will fail because external services cannot access the frames.")
if __name__ == "__main__":
    if APP_BASE_URL == "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE":  # Placeholder check
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("!!! CRITICAL WARNING: APP_BASE_URL is NOT SET or using a placeholder.                                !!!")
        print("!!! External NSFW checks will likely FAIL.                                                           !!!")
        print("!!! For local testing: expose this app (e.g., with ngrok) and set APP_BASE_URL to the ngrok URL.     !!!")
        print("!!! When deploying to Hugging Face Spaces, set the APP_BASE_URL environment variable in Space settings. !!!")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    # On Hugging Face Spaces with a Dockerfile, the CMD is usually `python app.py`;
    # Gradio handles the web server itself.
    app_interface.queue()  # Enable the queue to better handle multiple and long-running requests.
    app_interface.launch()  # share=True can provide a temporary public link when running locally.