import gradio as gr
import cv2 # OpenCV for video processing
import os # For interacting with the file system
import requests # For making HTTP requests
import random
import string
import json
import shutil # For file operations
import ast # For safely evaluating string literals
import uuid # For generating unique filenames
import asyncio # For concurrent operations
import time # For retries and delays
import logging # For structured logging
from typing import Optional # For the Optional return-type hints used below
# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Configuration ---
# Directory for temporarily storing extracted frames.
GRADIO_TEMP_FRAME_DIR = "/tmp/gradio_nsfw_frames_advanced"
os.makedirs(GRADIO_TEMP_FRAME_DIR, exist_ok=True)
# The public URL of this Gradio Space. Crucial for external NSFW checkers.
# Set via environment variable or update placeholder if hardcoding.
APP_BASE_URL = os.getenv("APP_BASE_URL", "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE")
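# Illustrative example (hypothetical URL): on Hugging Face Spaces, set this under
# Settings -> Variables, e.g. APP_BASE_URL=https://username-spacename.hf.space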
DEFAULT_REQUEST_TIMEOUT = 20
MAX_RETRY_ATTEMPTS = 3
RETRY_BACKOFF_FACTOR = 2
# Shared cap on concurrent outbound HTTP requests. A semaphore created fresh inside
# each call would never actually limit anything, so it lives at module level.
# (Assumes Python 3.10+, where asyncio primitives no longer bind to a loop at creation.)
HTTP_SEMAPHORE = asyncio.Semaphore(10)
# --- NSFW Checker Configuration (from FastAPI version) ---
NSFW_CHECKER_CONFIG = {
"checker1_yoinked": {
"queue_join_url": "https://yoinked-da-nsfw-checker.hf.space/queue/join",
"queue_data_url_template": "https://yoinked-da-nsfw-checker.hf.space/queue/data?session_hash={session_hash}",
"payload_template": lambda img_url, session_hash: {
'data': [{'path': img_url}, "chen-convnext", 0.5, True, True],
'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 12
}
},
"checker2_jamescookjr90": {
"queue_join_url": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/join",
"queue_data_url_template": "https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
"payload_template": lambda img_url, session_hash: {
'data': [{'path': img_url}],
'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
}
},
"checker3_zanderlewis": {
"predict_url": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict",
"event_url_template": "https://zanderlewis-xl-nsfw-detection.hf.space/call/predict/{event_id}",
"payload_template": lambda img_url: {'data': [{'path': img_url}]}
},
"checker4_error466": {
"base_url": "https://error466-falconsai-nsfw-image-detection.hf.space",
"replica_code_needed": True,
"queue_join_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/join",
"queue_data_url_template": "https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/data?session_hash={session_hash}",
"payload_template": lambda img_url, session_hash: {
'data': [{'path': img_url}],
'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 58
}
},
"checker5_phelpsgg": { # Original user code had 'imseldrith', FastAPI had 'phelpsgg'. Using phelpsgg from FastAPI.
"queue_join_url": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/join",
"queue_data_url_template": "https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}",
"payload_template": lambda img_url, session_hash: {
'data': [{'path': img_url}],
'session_hash': session_hash, 'fn_index': 0, 'trigger_id': 9
}
}
}
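# Illustrative example: with a hypothetical frame URL and session hash, the
# checker2 payload_template above produces:
#   {'data': [{'path': 'https://example.hf.space/file=/tmp/frame.jpg'}],
#    'session_hash': 'a1b2c3d4e5', 'fn_index': 0, 'trigger_id': 9}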
# --- Helper Functions (from FastAPI version, adapted for Gradio context) ---
async def http_request_with_retry(method: str, url: str, **kwargs) -> Optional[requests.Response]:
headers = kwargs.pop("headers", {})
headers.setdefault("User-Agent", "GradioNSFWClient/1.0")
for attempt in range(MAX_RETRY_ATTEMPTS):
try:
            async with HTTP_SEMAPHORE:
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: requests.request(method, url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT, **kwargs)
)
response.raise_for_status()
return response
except requests.exceptions.Timeout:
logger.warning(f"Request timeout for {url} on attempt {attempt + 1}")
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code in [429, 502, 503, 504]:
logger.warning(f"HTTP error {e.response.status_code} for {url} on attempt {attempt + 1}")
else:
logger.error(f"Non-retriable HTTP error for {url}: {e}")
return e.response if e.response is not None else None
except requests.exceptions.RequestException as e:
logger.error(f"Request exception for {url} on attempt {attempt + 1}: {e}")
if attempt < MAX_RETRY_ATTEMPTS - 1:
delay = (RETRY_BACKOFF_FACTOR ** attempt) + random.uniform(0, 0.5)
logger.info(f"Retrying {url} in {delay:.2f} seconds...")
await asyncio.sleep(delay)
logger.error(f"All {MAX_RETRY_ATTEMPTS} retry attempts failed for {url}.")
return None
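# With the defaults above, the delay between the three attempts is roughly
# 1s then 2s (RETRY_BACKOFF_FACTOR ** attempt) plus up to 0.5s of random jitter.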
def get_replica_code_sync(url: str) -> Optional[str]:
try:
r = requests.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers={"User-Agent": "GradioNSFWClient/1.0"})
r.raise_for_status()
parts = r.text.split('replicas/')
if len(parts) > 1:
return parts[1].split('"};')[0]
logger.warning(f"Could not find 'replicas/' in content from {url}")
return None
except (requests.exceptions.RequestException, IndexError, KeyError) as e:
logger.error(f"Error getting replica code for {url}: {e}")
return None
async def get_replica_code(url: str) -> Optional[str]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, get_replica_code_sync, url)
async def parse_hf_queue_response(response_content: str) -> Optional[str]:
try:
messages = response_content.strip().split('\n')
for msg_str in reversed(messages):
if msg_str.startswith("data:"):
try:
data_json_str = msg_str[len("data:"):].strip()
if not data_json_str: continue
parsed_json = json.loads(data_json_str)
if parsed_json.get("msg") == "process_completed":
output_data = parsed_json.get("output", {}).get("data")
if output_data and isinstance(output_data, list) and len(output_data) > 0:
first_item = output_data[0]
if isinstance(first_item, dict): return first_item.get('label')
if isinstance(first_item, str): return first_item
logger.warning(f"Unexpected 'process_completed' data structure: {output_data}")
return None
except json.JSONDecodeError:
logger.debug(f"Failed to decode JSON from part of HF stream: {data_json_str[:100]}")
continue
return None
except Exception as e:
logger.error(f"Error parsing HF queue response: {e}, content: {response_content[:200]}")
return None
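# Illustrative sketch of the newline-delimited SSE-style stream this parser expects
# (field names come from the checks above; values are made up):
#   data: {"msg": "...", ...}                                      <- skipped
#   data: {"msg": "process_completed", "output": {"data": [{"label": "nsfw"}]}}  <- yields "nsfw"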
async def check_nsfw_single_generic(checker_name: str, img_url: str) -> Optional[str]:
config = NSFW_CHECKER_CONFIG.get(checker_name)
if not config:
logger.error(f"No configuration found for checker: {checker_name}")
return None
session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(10))
try:
if "predict_url" in config: # ZanderLewis-like
payload = config["payload_template"](img_url)
response_predict = await http_request_with_retry("POST", config["predict_url"], json=payload)
if not response_predict or response_predict.status_code != 200:
logger.error(f"{checker_name} predict call failed. Status: {response_predict.status_code if response_predict else 'N/A'}")
return None
json_data = response_predict.json()
event_id = json_data.get('event_id')
if not event_id:
logger.error(f"{checker_name} did not return event_id.")
return None
event_url = config["event_url_template"].format(event_id=event_id)
for _ in range(10):
await asyncio.sleep(random.uniform(1.5, 2.5))
response_event = await http_request_with_retry("GET", event_url) # Removed stream=True as iter_content not used directly
if response_event and response_event.status_code == 200:
event_stream_content = response_event.text
if 'data:' in event_stream_content:
final_data_str = event_stream_content.strip().split('data:')[-1].strip()
if final_data_str:
try:
parsed_list = ast.literal_eval(final_data_str)
if isinstance(parsed_list, list) and parsed_list and isinstance(parsed_list[0], dict):
return parsed_list[0].get('label')
logger.warning(f"{checker_name} parsed non-list or empty list from event stream: {final_data_str[:100]}")
except (SyntaxError, ValueError, IndexError, TypeError) as e:
logger.warning(f"{checker_name} error parsing event stream: {e}, stream: {final_data_str[:100]}")
elif response_event:
logger.warning(f"{checker_name} polling event_url returned status {response_event.status_code}")
else:
logger.warning(f"{checker_name} polling event_url got no response.")
else: # Queue-based APIs
join_url = config["queue_join_url"]
data_url_template = config["queue_data_url_template"]
if config.get("replica_code_needed"):
replica_base_url = config.get("base_url")
if not replica_base_url:
logger.error(f"{checker_name} needs replica_code but base_url is missing.")
return None
code = await get_replica_code(replica_base_url)
if not code:
logger.error(f"Failed to get replica code for {checker_name}")
return None
join_url = config["queue_join_url_template"].format(code=code)
data_url = data_url_template.format(code=code, session_hash=session_hash)
else:
data_url = data_url_template.format(session_hash=session_hash)
payload = config["payload_template"](img_url, session_hash)
response_join = await http_request_with_retry("POST", join_url, json=payload)
if not response_join or response_join.status_code != 200:
logger.error(f"{checker_name} queue/join call failed. Status: {response_join.status_code if response_join else 'N/A'}")
return None
for _ in range(15):
await asyncio.sleep(random.uniform(1.5, 2.5))
response_data = await http_request_with_retry("GET", data_url, stream=True)
if response_data and response_data.status_code == 200:
buffer = ""
                    # iter_content is synchronous and runs here on the event loop thread
                    # (only the initial request ran in the executor), so it briefly
                    # blocks the loop while chunks arrive.
                    for content_chunk in response_data.iter_content(chunk_size=1024, decode_unicode=True):
                        if content_chunk:
                            buffer += content_chunk
                            # SSE events are delimited by a blank line; test the raw
                            # buffer, since strip() would remove the trailing newlines
                            # and an endswith("}\n\n") check could never match.
                            if buffer.endswith("\n\n"):
                                label = await parse_hf_queue_response(buffer) # parse_hf_queue_response is async
                                if label: return label
                                buffer = ""
elif response_data:
logger.warning(f"{checker_name} polling queue/data returned status {response_data.status_code}")
else:
logger.warning(f"{checker_name} polling queue/data got no response.")
logger.warning(f"{checker_name} failed to get a conclusive result for {img_url}")
return None
except Exception as e:
logger.error(f"Exception in {checker_name} for {img_url}: {e}", exc_info=True)
return None
async def check_nsfw_final_concurrent(img_url: str) -> Optional[bool]:
logger.info(f"Starting NSFW check for: {img_url}")
# Prioritized list from FastAPI version
checker_names = [
"checker2_jamescookjr90", "checker3_zanderlewis", "checker5_phelpsgg",
"checker4_error466", "checker1_yoinked"
]
named_tasks = {
name: asyncio.create_task(check_nsfw_single_generic(name, img_url))
for name in checker_names
}
sfw_found_by_any_checker = False
    # The tasks above were created together, so all checkers already run concurrently;
    # this loop merely consumes their results in checker_names priority order.
    # For a true "first result wins" policy, asyncio.as_completed would be preferable.
for task_name in checker_names:
try:
label = await named_tasks[task_name]
logger.info(f"Checker '{task_name}' result for {img_url}: {label}")
if label:
label_lower = label.lower()
if 'nsfw' in label_lower:
logger.info(f"NSFW detected by '{task_name}' for {img_url}. Final: True.")
# Cancel remaining tasks
for t_name_to_cancel, t_obj_to_cancel in named_tasks.items():
if t_name_to_cancel != task_name and not t_obj_to_cancel.done():
t_obj_to_cancel.cancel()
return True
if 'sfw' in label_lower or 'safe' in label_lower:
sfw_found_by_any_checker = True
except asyncio.CancelledError:
logger.info(f"Checker '{task_name}' was cancelled for {img_url}.")
except Exception as e:
logger.error(f"Error processing result from checker '{task_name}' for {img_url}: {e}")
if sfw_found_by_any_checker:
logger.info(f"SFW confirmed for {img_url} (no NSFW detected by any checker, at least one SFW). Final: False.")
return False
logger.warning(f"All NSFW checkers inconclusive or failed for {img_url}. Final: None.")
return None
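# Minimal usage sketch (hypothetical URL; for use outside a running event loop):
#   verdict = asyncio.run(check_nsfw_final_concurrent("https://example.com/frame.jpg"))
#   # verdict: True = NSFW detected, False = SFW confirmed, None = inconclusive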
# --- Gradio Specific Functions ---
def extract_frames_sync(video_path: str, num_frames_to_extract: int, progress:gr.Progress=None) -> list:
if progress: progress(0, desc="Starting frame extraction...")
vidcap = cv2.VideoCapture(video_path)
if not vidcap.isOpened():
logger.error(f"Error: Cannot open video file {video_path}")
return []
total_frames_in_video = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
extracted_filenames = [] # Stores only filenames, not full paths
if total_frames_in_video == 0:
vidcap.release()
return []
actual_frames_to_extract = min(num_frames_to_extract, total_frames_in_video)
if actual_frames_to_extract == 0 and total_frames_in_video > 0: actual_frames_to_extract = 1
if actual_frames_to_extract == 0:
vidcap.release()
return []
for i in range(actual_frames_to_extract):
if progress: progress(i / actual_frames_to_extract, desc=f"Extracting frame {i+1}/{actual_frames_to_extract}")
frame_number = int(i * total_frames_in_video / actual_frames_to_extract)
frame_number = min(frame_number, total_frames_in_video - 1)
vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
success, image = vidcap.read()
if success:
random_filename = f"{uuid.uuid4().hex}.jpg"
full_frame_path = os.path.join(GRADIO_TEMP_FRAME_DIR, random_filename)
try:
cv2.imwrite(full_frame_path, image)
extracted_filenames.append(random_filename)
except Exception as e:
logger.error(f"Error writing frame {full_frame_path}: {e}")
else:
logger.warning(f"Warning: Failed to read frame at position {frame_number} from {video_path}")
vidcap.release()
if progress: progress(1, desc="Frame extraction complete.")
return extracted_filenames
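# Note: extract_frames_sync returns bare filenames (e.g. "3fa8...e1.jpg");
# callers must rejoin them with GRADIO_TEMP_FRAME_DIR to get usable paths,
# as process_video_gradio does below.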
async def process_video_gradio(video_temp_path: str, num_frames: int, progress=gr.Progress(track_tqdm=True)):
"""
Main async processing function for Gradio, using robust NSFW checkers.
"""
if APP_BASE_URL == "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE":
warning_msg = "CRITICAL WARNING: APP_BASE_URL is not set! External NSFW checks will likely fail. Please set the APP_BASE_URL environment variable to your public Gradio Space URL."
logger.error(warning_msg)
# Optionally, display this warning in the Gradio UI as well
# For now, returning it in the JSON output.
return {"error": warning_msg, "details": "The application needs to know its own public URL to construct frame URLs for external analysis services."}
if not video_temp_path:
return {"error": "No video file provided or video path is invalid."}
try:
num_frames = int(num_frames)
if num_frames <= 0:
return {"error": "Number of frames must be a positive integer."}
except (ValueError, TypeError):
return {"error": "Invalid number for frames."}
# Run synchronous frame extraction in an executor to keep the async event loop unblocked
loop = asyncio.get_event_loop()
extracted_frame_filenames = await loop.run_in_executor(
None, extract_frames_sync, video_temp_path, num_frames, progress
)
if not extracted_frame_filenames:
# Cleanup the uploaded video file if it exists and extraction failed
if os.path.exists(video_temp_path):
try: os.remove(video_temp_path)
except Exception as e: logger.error(f"Error cleaning up input video {video_temp_path} after failed extraction: {e}")
return {"error": "Could not extract any frames from the video."}
nsfw_count = 0
total_frames_processed = len(extracted_frame_filenames)
frame_results_output = []
analysis_coroutines = []
for frame_filename in extracted_frame_filenames:
# Construct the URL for the NSFW checker using Gradio's /file= route
# The path for /file= should be the absolute path on the server where Gradio can find the file.
absolute_frame_path_on_server = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
publicly_accessible_frame_url = f"{APP_BASE_URL.rstrip('/')}/file={absolute_frame_path_on_server}"
analysis_coroutines.append(check_nsfw_final_concurrent(publicly_accessible_frame_url))
# Update progress for analysis phase
# Since gr.Progress doesn't directly map to asyncio.gather, we'll set a general message.
if progress: progress(0.5, desc=f"Analyzing {total_frames_processed} frames (may take time)...")
nsfw_detection_results = await asyncio.gather(*analysis_coroutines, return_exceptions=True)
# Update progress after analysis
if progress: progress(0.9, desc="Compiling results...")
for i, detection_result in enumerate(nsfw_detection_results):
frame_filename = extracted_frame_filenames[i]
absolute_frame_path_on_server = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
publicly_accessible_frame_url = f"{APP_BASE_URL.rstrip('/')}/file={absolute_frame_path_on_server}"
is_nsfw_str = "unknown"
if isinstance(detection_result, Exception):
logger.error(f"Error analyzing frame {publicly_accessible_frame_url}: {detection_result}")
is_nsfw_str = "error"
else: # detection_result is True, False, or None
if detection_result is True:
nsfw_count += 1
is_nsfw_str = "true"
elif detection_result is False:
is_nsfw_str = "false"
frame_results_output.append({
"frame_filename_on_server": frame_filename,
"checked_url": publicly_accessible_frame_url,
"nsfw_detected": is_nsfw_str
})
if progress: progress(1, desc="Analysis complete. Cleaning up temporary files...")
# Cleanup extracted frames
for frame_filename in extracted_frame_filenames:
full_frame_path_to_delete = os.path.join(GRADIO_TEMP_FRAME_DIR, frame_filename)
if os.path.exists(full_frame_path_to_delete):
try:
os.remove(full_frame_path_to_delete)
except Exception as e:
logger.error(f"Error deleting frame {full_frame_path_to_delete}: {e}")
# Gradio manages the `video_temp_path` (uploaded video) cleanup.
final_result_json = {
"summary": {
"nsfw_frames_found": nsfw_count,
"total_frames_analyzed": total_frames_processed,
"app_base_url_used_for_checks": APP_BASE_URL,
"frames_temp_dir_on_server": GRADIO_TEMP_FRAME_DIR
},
"frame_details": frame_results_output
}
return final_result_json
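# Illustrative shape of the JSON returned above (values are examples only):
# {
#   "summary": {"nsfw_frames_found": 2, "total_frames_analyzed": 10,
#               "app_base_url_used_for_checks": "https://...",
#               "frames_temp_dir_on_server": "/tmp/gradio_nsfw_frames_advanced"},
#   "frame_details": [{"frame_filename_on_server": "ab12cd.jpg",
#                      "checked_url": "https://.../file=/tmp/gradio_nsfw_frames_advanced/ab12cd.jpg",
#                      "nsfw_detected": "true"}]
# }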
# --- Gradio Interface Definition ---
with gr.Blocks(css="footer {display: none !important;}", title="NSFW Video Detector") as app_interface:
gr.Markdown(
f"""
# NSFW Frame Detection in Video
Upload a video and specify the number of frames to check for NSFW content.
The analysis uses a series of external NSFW detection models with retries and concurrent checks for robustness.
"""
)
with gr.Accordion("Important: How this Space Works & Configuration", open=False):
gr.Markdown(
f"""
- **`APP_BASE_URL`**: For this Space to work correctly when deployed (e.g., on Hugging Face Spaces), the `APP_BASE_URL` environment variable **must** be set to its public URL (e.g., `https://your-username-your-spacename.hf.space`). This is because external NSFW checkers need to access the extracted video frames via public URLs.
- **Currently configured `APP_BASE_URL` for frame checking: `{APP_BASE_URL}`**. If this shows the placeholder, the NSFW checks will likely fail.
- **Temporary Frame Storage**: Frames are temporarily extracted to `{GRADIO_TEMP_FRAME_DIR}` on the server and deleted after processing.
- **Processing Time**: Depends on video length, number of frames selected, and the responsiveness of external NSFW checking services. Please be patient.
"""
)
with gr.Row():
with gr.Column(scale=1):
video_input = gr.Video(label="Upload Video")
num_frames_input = gr.Number(label="Number of Frames to Check (e.g., 5-20)", value=10, minimum=1, maximum=50, step=1, precision=0)
submit_button = gr.Button("Detect NSFW Content", variant="primary")
with gr.Column(scale=2):
output_json = gr.JSON(label="Detection Result")
submit_button.click(
fn=process_video_gradio, # Use the new async handler
inputs=[video_input, num_frames_input],
outputs=output_json
)
gr.Examples(
examples=[
# Provide path to a sample video if you have one in your Space repo
# Example: [os.path.join(os.path.dirname(__file__), "sample_video.mp4"), 5],
],
inputs=[video_input, num_frames_input],
outputs=output_json,
fn=process_video_gradio, # Ensure example uses the async handler too
cache_examples=False
)
gr.Markdown("Note: If `APP_BASE_URL` is not correctly set to this Space's public URL, the NSFW detection for frames will fail as external services won't be able to access them.")
if __name__ == "__main__":
if APP_BASE_URL == "YOUR_GRADIO_SPACE_PUBLIC_URL_HERE": # Check placeholder
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!! CRITICAL WARNING: APP_BASE_URL is NOT SET or using a placeholder. !!!")
print("!!! External NSFW checks will likely FAIL. !!!")
print("!!! For local testing: Expose this app (e.g., with ngrok) and set APP_BASE_URL to the ngrok URL. !!!")
print("!!! When deploying to Hugging Face Spaces, set the APP_BASE_URL environment variable in Space settings.!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
# For Gradio, app.launch() is typically how you start it.
# If running in a Hugging Face Space with a Dockerfile, the CMD usually is `python app.py`.
# Gradio handles the web server part.
    app_interface.queue() # Enable queue for handling multiple requests and longer tasks better.
    # allowed_paths lets Gradio's /file= route serve the extracted frames; newer
    # Gradio versions refuse to serve arbitrary server paths without it.
    app_interface.launch(allowed_paths=[GRADIO_TEMP_FRAME_DIR]) # share=True gives a temporary public link when running locally.