from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import cv2
import uuid
import os
import requests
import random
import string
import json
import ast

app = FastAPI()
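
# The check_nsfw* helpers below drive public Hugging Face Spaces through
# Gradio's queue protocol: POST /queue/join enqueues a prediction, then
# GET /queue/data streams server-sent events for that session_hash until
# the result (or an error) arrives.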
def check_nsfw(img_url):
    """Classify an image with the chen-convnext checker Space."""
    # Random session hash, as the Gradio queue protocol expects
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        'data': [
            {'path': img_url},
            "chen-convnext",
            0.5,
            True,
            True
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 12
    }
    requests.post('https://yoinked-da-nsfw-checker.hf.space/queue/join', json=data)
    r = requests.get(f'https://yoinked-da-nsfw-checker.hf.space/queue/data?session_hash={session_hash}', stream=True)
    raw = b""  # Accumulate the stream as bytes; decode once so chunk boundaries cannot split a UTF-8 character
    for content in r.iter_content(100):
        raw += content
    buffer = raw.decode('utf-8')
    print(buffer)
    # The second-to-last "data:" event carries the prediction payload
    return json.loads(buffer.split('data:')[-2])["output"]["data"][0]['label']

def check_nsfw2(img_url):
    """Classify an image with the jamescookjr90 Falconsai checker Space."""
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        'data': [
            {'path': img_url}
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 9
    }
    requests.post('https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/join', json=data)
    r = requests.get(f'https://jamescookjr90-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}', stream=True)
    raw = b""  # Accumulate the stream as bytes, then decode once
    for content in r.iter_content(100):
        raw += content
    buffer = raw.decode('utf-8')
    print(buffer)
    # The second-to-last "data:" event carries the prediction payload
    return json.loads(buffer.split('data:')[-2])["output"]["data"][0]['label']

def check_nsfw3(img_url):
    """Classify an image with the zanderlewis Space via its REST /call/predict API."""
    data = {
        'data': [
            {'path': img_url}
        ]
    }
    # Two-step API: the POST returns an event_id, then the GET streams the result
    r = requests.post('https://zanderlewis-xl-nsfw-detection.hf.space/call/predict', json=data)
    event_id = r.json()['event_id']
    r = requests.get(f'https://zanderlewis-xl-nsfw-detection.hf.space/call/predict/{event_id}', stream=True)
    raw = b""
    for chunk in r.iter_content(100):
        raw += chunk
    event_stream = raw.decode('utf-8')
    print(event_stream)
    # The last "data:" event holds a Python-literal list of predictions
    return ast.literal_eval(event_stream.split('data:')[-1])[0]['label']

def get_replica_code(url):
    """Scrape the current replica ID from a Space's HTML; return None on failure."""
    try:
        r = requests.get(url)
        return r.text.split('replicas/')[1].split('"};')[0]
    except Exception:
        return None

def check_nsfw4(img_url):
    """Classify an image with the error466 Space, routing through its current replica."""
    code = get_replica_code('https://error466-falconsai-nsfw-image-detection.hf.space')
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        'data': [
            {'path': img_url}
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 58
    }
    requests.post(f'https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/join', json=data)
    r = requests.get(f'https://error466-falconsai-nsfw-image-detection.hf.space/--replicas/{code}/queue/data?session_hash={session_hash}', stream=True)
    raw = b""
    for content in r.iter_content(100):
        raw += content
    buffer = raw.decode('utf-8')
    print(buffer)
    return json.loads(buffer.split('data:')[-1])["output"]["data"][0]['label']

def check_nsfw5(img_url):
    """Classify an image with the phelpsgg Falconsai checker Space."""
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        'data': [
            {'path': img_url}
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 9
    }
    requests.post('https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/join', json=data)
    r = requests.get(f'https://phelpsgg-falconsai-nsfw-image-detection.hf.space/queue/data?session_hash={session_hash}', stream=True)
    raw = b""
    for content in r.iter_content(100):
        raw += content
    buffer = raw.decode('utf-8')
    print(buffer)
    return json.loads(buffer.split('data:')[-1])["output"]["data"][0]['label']

def check_nsfw_final(img_url):
    """Try each detector in turn; return True for an 'nsfw' label, False for
    any other label, or None if every detector raised."""
    print(img_url)
    for checker in (check_nsfw2, check_nsfw3, check_nsfw4, check_nsfw5):
        try:
            return checker(img_url) == 'nsfw'
        except Exception as e:
            print(e)
    return None
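
# A minimal usage sketch of the fallback chain (assumes the Spaces above are
# awake; the image URL here is hypothetical):
#
#   verdict = check_nsfw_final("https://example.com/sample.jpg")
#   print({True: "nsfw", False: "safe", None: "all detectors failed"}[verdict])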

# Directory for serving frame images
FRAMES_DIR = "/tmp/frames"
os.makedirs(FRAMES_DIR, exist_ok=True)

# Mount static file route
app.mount("/frames", StaticFiles(directory=FRAMES_DIR), name="frames")
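# Frames written to FRAMES_DIR become reachable at /frames/<filename>, which
# is how the checker Spaces fetch them by URL.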

# Frame extraction
def extract_frames(video_path, num_frames, temp_dir):
    """Sample num_frames evenly spaced frames from the video and save them as JPEGs."""
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    for i in range(num_frames):
        # Seek to the i-th evenly spaced frame
        frame_number = int(i * total_frames / num_frames)
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        success, image = vidcap.read()
        if success:
            frame_filename = os.path.join(temp_dir, f"{uuid.uuid4().hex}.jpg")
            cv2.imwrite(frame_filename, image)
            frames.append(frame_filename)
        else:
            break
    vidcap.release()
    return frames
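
# For example, with total_frames=300 and num_frames=10 the loop samples
# frames 0, 30, 60, ..., 270, i.e. evenly spaced from the start of the video.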

# Video processing
def process_video(video_path, num_frames):
    os.makedirs(FRAMES_DIR, exist_ok=True)  # Ensure frames directory exists
    # Extract frames to the static files directory
    frames = extract_frames(video_path, num_frames, FRAMES_DIR)
    nsfw_count = 0
    total_frames = len(frames)
    frame_results = []
    for frame_path in frames:
        # Construct the public HTTP URL for the frame
        frame_filename = os.path.basename(frame_path)
        img_url = f"https://sdafd-nsfw-video-detect-api.hf.space/frames/{frame_filename}"
        nsfw_detected = check_nsfw_final(img_url)
        frame_results.append({
            "frame_path": img_url,  # Use the HTTP URL
            "nsfw_detected": nsfw_detected
        })
        if nsfw_detected:
            nsfw_count += 1
    return {
        "nsfw_count": nsfw_count,
        "total_frames": total_frames,
        "frames": frame_results
    }
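
# Note: extracted frames stay in FRAMES_DIR so the checker Spaces can fetch
# them over HTTP while the request is in flight; nothing here deletes them
# afterwards.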

# Request/Response models
class VideoRequest(BaseModel):
    video_url: str
    num_frames: int = 10

class VideoResponse(BaseModel):
    nsfw_count: int
    total_frames: int
    frames: list

# API Endpoints
@app.post("/process_video", response_model=VideoResponse)  # Route path assumed; the original omitted the decorator
async def process_video_endpoint(request: VideoRequest):
    # Download video
    try:
        video_path = os.path.join("/tmp", f"{uuid.uuid4().hex}.mp4")
        with requests.get(request.video_url, stream=True) as r:
            r.raise_for_status()
            with open(video_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error downloading video: {e}")
    try:
        # Process video
        result = process_video(video_path, request.num_frames)
    finally:
        # Cleanup video file
        if os.path.exists(video_path):
            os.remove(video_path)
    return result
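
# A client-side sketch for exercising the endpoint (assumes the app is served
# at the Space URL hard-coded above and the route path added in the decorator):
#
#   import requests
#   resp = requests.post(
#       "https://sdafd-nsfw-video-detect-api.hf.space/process_video",
#       json={"video_url": "https://example.com/clip.mp4", "num_frames": 10},
#   )
#   body = resp.json()
#   print(body["nsfw_count"], "of", body["total_frames"], "frames flagged")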