# Hugging Face Space page header (scraped residue): author sdafd,
# commit a681854 "Update app.py" (verified).
import ast
import base64
import json
import os
import random
import shutil
import string
import tempfile
import uuid

import cv2
import requests
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pydantic import Field
# ASGI application; the /frames static mount and the /process_video route are registered on it.
app = FastAPI()
def check_nsfw(img_url, timeout=30):
    """Classify an image with the Gradio Space at yoinked-da-nsfw-checker.hf.space.

    Joins the Space's queue under a random session hash, reads the
    server-sent-event (SSE) stream for that session until the server closes
    it, and returns the ``label`` of the first item in the last event that
    carries an ``output`` payload.

    Args:
        img_url: URL of the image; the Space must be able to fetch it.
        timeout: Connect/read timeout in seconds for each HTTP call, so a
            stalled Space raises instead of hanging the caller forever.

    Returns:
        The label string reported by the Space.

    Raises:
        requests.RequestException: network failure or non-2xx response.
        ValueError: the stream contained no event with an ``output`` payload.
        KeyError, IndexError: the output payload lacked ``data[0]['label']``.
    """
    base_url = 'https://yoinked-da-nsfw-checker.hf.space'
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        # Positional inputs for the Space's endpoint; their meaning is defined
        # by the Space itself (not visible here).
        'data': [
            {'path': img_url},
            "chen-convnext",
            0.5,
            True,
            True
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 12
    }
    join = requests.post(f'{base_url}/queue/join', json=data, timeout=timeout)
    join.raise_for_status()
    with requests.get(f'{base_url}/queue/data?session_hash={session_hash}',
                      stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # Join the raw bytes and decode once: decoding each 100-byte chunk on
        # its own raises UnicodeDecodeError whenever a multi-byte UTF-8
        # character straddles a chunk boundary.
        buffer = b''.join(r.iter_content(100)).decode('utf-8')
    print(buffer)
    # Select the last SSE event that has an "output" key instead of relying on
    # the result's position in the stream, which differs between Spaces.
    output = None
    for line in buffer.splitlines():
        if not line.startswith('data:'):
            continue
        try:
            event = json.loads(line[len('data:'):])
        except ValueError:
            continue  # not a JSON payload; ignore it
        if isinstance(event, dict) and 'output' in event:
            output = event['output']
    if output is None:
        raise ValueError(f'no output event in NSFW checker stream: {buffer[:200]!r}')
    return output['data'][0]['label']
def check_nsfw2(img_url, timeout=30):
    """Classify an image with the Gradio Space at jamescookjr90-falconsai-nsfw-image-detection.hf.space.

    Joins the Space's queue under a random session hash, reads the
    server-sent-event (SSE) stream for that session until the server closes
    it, and returns the ``label`` of the first item in the last event that
    carries an ``output`` payload.

    Args:
        img_url: URL of the image; the Space must be able to fetch it.
        timeout: Connect/read timeout in seconds for each HTTP call, so a
            stalled Space raises instead of hanging the caller forever.

    Returns:
        The label string reported by the Space.

    Raises:
        requests.RequestException: network failure or non-2xx response.
        ValueError: the stream contained no event with an ``output`` payload.
        KeyError, IndexError: the output payload lacked ``data[0]['label']``.
    """
    base_url = 'https://jamescookjr90-falconsai-nsfw-image-detection.hf.space'
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        'data': [
            {'path': img_url}
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 9
    }
    join = requests.post(f'{base_url}/queue/join', json=data, timeout=timeout)
    join.raise_for_status()
    with requests.get(f'{base_url}/queue/data?session_hash={session_hash}',
                      stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # Join the raw bytes and decode once: decoding each 100-byte chunk on
        # its own raises UnicodeDecodeError whenever a multi-byte UTF-8
        # character straddles a chunk boundary.
        buffer = b''.join(r.iter_content(100)).decode('utf-8')
    print(buffer)
    # Select the last SSE event that has an "output" key instead of relying on
    # the result's position in the stream, which differs between Spaces.
    output = None
    for line in buffer.splitlines():
        if not line.startswith('data:'):
            continue
        try:
            event = json.loads(line[len('data:'):])
        except ValueError:
            continue  # not a JSON payload; ignore it
        if isinstance(event, dict) and 'output' in event:
            output = event['output']
    if output is None:
        raise ValueError(f'no output event in NSFW checker stream: {buffer[:200]!r}')
    return output['data'][0]['label']
def check_nsfw3(img_url, timeout=30):
    """Classify an image with the Gradio Space at zanderlewis-xl-nsfw-detection.hf.space.

    Uses the two-step ``/call/predict`` API: a POST returns an ``event_id``,
    then a GET streams server-sent events for that id. The JSON ``data`` of
    the ``complete`` event is parsed and the first item's ``label`` returned.

    Args:
        img_url: URL of the image; the Space must be able to fetch it.
        timeout: Connect/read timeout in seconds for each HTTP call, so a
            stalled Space raises instead of hanging the caller forever.

    Returns:
        The label string reported by the Space.

    Raises:
        requests.RequestException: network failure or non-2xx response.
        RuntimeError: the stream reported an ``error`` event.
        ValueError: no ``complete`` event was received, or its data is not JSON.
    """
    predict_url = 'https://zanderlewis-xl-nsfw-detection.hf.space/call/predict'
    data = {
        'data': [
            {'path': img_url}
        ]
    }
    r = requests.post(predict_url, json=data, timeout=timeout)
    r.raise_for_status()
    event_id = r.json()['event_id']
    with requests.get(f'{predict_url}/{event_id}', stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # Decode once after joining the bytes so a multi-byte UTF-8 character
        # split across 100-byte chunks cannot raise UnicodeDecodeError.
        event_stream = b''.join(r.iter_content(100)).decode('utf-8')
    print(event_stream)
    event_name = None
    result = None
    for line in event_stream.splitlines():
        if line.startswith('event:'):
            event_name = line[len('event:'):].strip()
        elif line.startswith('data:'):
            payload = line[len('data:'):].strip()
            if event_name == 'error':
                raise RuntimeError(f'NSFW checker returned an error: {payload}')
            if event_name == 'complete':
                # The payload is JSON (it may contain true/false/null), so it
                # is parsed with json rather than ast.literal_eval.
                result = json.loads(payload)
    if result is None:
        raise ValueError(f'no complete event in NSFW checker stream: {event_stream[:200]!r}')
    return result[0]['label']
def get_replica_code(url, timeout=30):
    """Scrape the replica id from a Gradio Space's HTML page.

    Fetches ``url`` and returns the text between the first ``replicas/`` and
    the following ``"};`` (presumably the id used in ``--replicas/<id>`` URLs
    by the caller; verify if the Space's page layout changes).

    Args:
        url: Page to fetch.
        timeout: Connect/read timeout in seconds for the request.

    Returns:
        The extracted string, or None if the request fails, the server
        answers with an error status, or the ``replicas/`` marker is absent.
    """
    try:
        r = requests.get(url, timeout=timeout)
        r.raise_for_status()
        return r.text.split('replicas/')[1].split('"};')[0]
    except (requests.RequestException, IndexError):
        # Only network/HTTP failures and a missing marker mean "no code";
        # anything else (including KeyboardInterrupt) propagates.
        return None
def check_nsfw4(img_url, timeout=30):
    """Classify an image with the Gradio Space at error466-falconsai-nsfw-image-detection.hf.space.

    The Space is addressed through a ``--replicas/<code>`` path whose code is
    scraped from its page by ``get_replica_code``. The function joins the
    queue under a random session hash, reads the server-sent-event stream
    until the server closes it, and returns the ``label`` of the first item in
    the last event that carries an ``output`` payload.

    Args:
        img_url: URL of the image; the Space must be able to fetch it.
        timeout: Connect/read timeout in seconds for the queue HTTP calls.

    Returns:
        The label string reported by the Space.

    Raises:
        RuntimeError: the replica code could not be determined.
        requests.RequestException: network failure or non-2xx response.
        ValueError: the stream contained no event with an ``output`` payload.
        KeyError, IndexError: the output payload lacked ``data[0]['label']``.
    """
    space_url = 'https://error466-falconsai-nsfw-image-detection.hf.space'
    code = get_replica_code(space_url)
    if code is None:
        # Without a code the URLs below would contain the literal text "None".
        raise RuntimeError(f'could not read replica code from {space_url}')
    base_url = f'{space_url}/--replicas/{code}'
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        'data': [
            {'path': img_url}
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 58
    }
    join = requests.post(f'{base_url}/queue/join', json=data, timeout=timeout)
    join.raise_for_status()
    with requests.get(f'{base_url}/queue/data?session_hash={session_hash}',
                      stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # Join the raw bytes and decode once: decoding each 100-byte chunk on
        # its own raises UnicodeDecodeError whenever a multi-byte UTF-8
        # character straddles a chunk boundary.
        buffer = b''.join(r.iter_content(100)).decode('utf-8')
    print(buffer)
    # Select the last SSE event that has an "output" key instead of relying on
    # the result's position in the stream, which differs between Spaces.
    output = None
    for line in buffer.splitlines():
        if not line.startswith('data:'):
            continue
        try:
            event = json.loads(line[len('data:'):])
        except ValueError:
            continue  # not a JSON payload; ignore it
        if isinstance(event, dict) and 'output' in event:
            output = event['output']
    if output is None:
        raise ValueError(f'no output event in NSFW checker stream: {buffer[:200]!r}')
    return output['data'][0]['label']
def check_nsfw5(img_url, timeout=30):
    """Classify an image with the Gradio Space at phelpsgg-falconsai-nsfw-image-detection.hf.space.

    Joins the Space's queue under a random session hash, reads the
    server-sent-event (SSE) stream for that session until the server closes
    it, and returns the ``label`` of the first item in the last event that
    carries an ``output`` payload.

    Args:
        img_url: URL of the image; the Space must be able to fetch it.
        timeout: Connect/read timeout in seconds for each HTTP call, so a
            stalled Space raises instead of hanging the caller forever.

    Returns:
        The label string reported by the Space.

    Raises:
        requests.RequestException: network failure or non-2xx response.
        ValueError: the stream contained no event with an ``output`` payload.
        KeyError, IndexError: the output payload lacked ``data[0]['label']``.
    """
    base_url = 'https://phelpsgg-falconsai-nsfw-image-detection.hf.space'
    session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(7))
    data = {
        'data': [
            {'path': img_url}
        ],
        'session_hash': session_hash,
        'fn_index': 0,
        'trigger_id': 9
    }
    join = requests.post(f'{base_url}/queue/join', json=data, timeout=timeout)
    join.raise_for_status()
    with requests.get(f'{base_url}/queue/data?session_hash={session_hash}',
                      stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # Join the raw bytes and decode once: decoding each 100-byte chunk on
        # its own raises UnicodeDecodeError whenever a multi-byte UTF-8
        # character straddles a chunk boundary.
        buffer = b''.join(r.iter_content(100)).decode('utf-8')
    print(buffer)
    # Select the last SSE event that has an "output" key instead of relying on
    # the result's position in the stream, which differs between Spaces.
    output = None
    for line in buffer.splitlines():
        if not line.startswith('data:'):
            continue
        try:
            event = json.loads(line[len('data:'):])
        except ValueError:
            continue  # not a JSON payload; ignore it
        if isinstance(event, dict) and 'output' in event:
            output = event['output']
    if output is None:
        raise ValueError(f'no output event in NSFW checker stream: {buffer[:200]!r}')
    return output['data'][0]['label']
def check_nsfw_final(img_url, *, checkers=None):
    """Classify an image, falling back across several remote NSFW checkers.

    Checkers are tried in order. The first one that returns without raising
    decides the result; a checker that raises has its error printed and the
    next one is tried.

    Args:
        img_url: URL of the image to classify.
        checkers: Optional sequence of callables ``f(img_url) -> label``.
            Defaults to ``check_nsfw2``, ``check_nsfw3``, ``check_nsfw4``,
            ``check_nsfw5``, in that order.

    Returns:
        True if the first successful checker's label is ``'nsfw'``, False for
        any other label, or None if every checker raised (or none was given).
    """
    print(img_url)
    if checkers is None:
        checkers = (check_nsfw2, check_nsfw3, check_nsfw4, check_nsfw5)
    for checker in checkers:
        try:
            label = checker(img_url)
        except Exception as e:  # backend down or malformed reply: try the next one
            print(e)
            continue
        return label == 'nsfw'
    return None
# Directory where process_video writes the JPEG frames extracted from downloaded
# videos. It is created here, before it is mounted below.
FRAMES_DIR = "/tmp/frames"
os.makedirs(FRAMES_DIR, exist_ok=True)
# Serve the frames at /frames/<file> so the remote NSFW checkers can fetch
# them by URL. NOTE(review): every file in FRAMES_DIR is publicly readable.
app.mount("/frames", StaticFiles(directory=FRAMES_DIR), name="frames")
# Frame extraction
def extract_frames(video_path, num_frames, temp_dir):
    """Save up to ``num_frames`` evenly spaced frames of a video as JPEG files.

    Frame ``i`` is read from position ``int(i * total / num_frames)``, where
    ``total`` is OpenCV's reported frame count. Extraction stops at the first
    frame that cannot be read or written.

    Args:
        video_path: Path of a video file readable by OpenCV.
        num_frames: Number of frames requested. When OpenCV reports a positive
            frame count smaller than this, the request is capped to that count
            so the same frame is not sampled (and classified) more than once.
            Zero or negative yields no frames.
        temp_dir: Existing directory in which ``<uuid>.jpg`` files are written.

    Returns:
        List of paths of the JPEG files actually written, in frame order.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            num_frames = min(num_frames, total_frames)
        for i in range(num_frames):
            frame_number = int(i * total_frames / num_frames)
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            success, image = vidcap.read()
            if not success:
                break
            frame_filename = os.path.join(temp_dir, f"{uuid.uuid4().hex}.jpg")
            # imwrite reports failure via its return value; don't hand back a
            # path to a file that was never written.
            if not cv2.imwrite(frame_filename, image):
                break
            frames.append(frame_filename)
    finally:
        # Release the capture handle even if reading or writing raises.
        vidcap.release()
    return frames
# Video processing
def process_video(video_path, num_frames, *, base_url="https://sdafd-nsfw-video-detect-api.hf.space"):
    """Extract frames from a video and classify each one as NSFW or not.

    Frames are written into FRAMES_DIR, which this app serves under
    ``/frames``, and each frame is handed to the remote classifiers by URL.

    Args:
        video_path: Local path of the video file.
        num_frames: Number of evenly spaced frames to sample.
        base_url: Public origin at which this app is reachable; classifiers
            fetch frames from ``{base_url}/frames/<file>``. Defaults to the
            original deployment's host, so other deployments can override it.

    Returns:
        dict with ``nsfw_count`` (frames classified NSFW), ``total_frames``
        (frames actually extracted) and ``frames``: one
        ``{"frame_path": url, "nsfw_detected": bool | None}`` per frame, where
        None (no checker succeeded) is not counted as NSFW.
    """
    # Re-create the directory in case it was removed after module import.
    os.makedirs(FRAMES_DIR, exist_ok=True)
    frames = extract_frames(video_path, num_frames, FRAMES_DIR)
    # NOTE(review): frames are left on disk (and publicly served) after this
    # returns; the returned URLs point at them, so cleanup needs a retention policy.
    frames_url = f"{base_url.rstrip('/')}/frames"
    frame_results = []
    for frame_path in frames:
        img_url = f"{frames_url}/{os.path.basename(frame_path)}"
        frame_results.append({
            "frame_path": img_url,  # Use HTTP URL
            "nsfw_detected": check_nsfw_final(img_url),
        })
    nsfw_count = sum(1 for fr in frame_results if fr["nsfw_detected"])
    return {
        "nsfw_count": nsfw_count,
        "total_frames": len(frames),
        "frames": frame_results,
    }
# Request/Response models
class VideoRequest(BaseModel):
    """Request body of POST /process_video."""

    # URL the server downloads the video from.
    video_url: str
    # Number of evenly spaced frames to sample and classify. Must be >= 1:
    # with 0 or a negative value nothing is sampled and the response would
    # report zero NSFW frames, which reads like a clean video.
    num_frames: int = Field(10, ge=1)
class VideoResponse(BaseModel):
    """Response body of POST /process_video (the dict built by process_video)."""
    # Number of sampled frames whose classification was NSFW.
    nsfw_count: int
    # Number of frames actually extracted and classified.
    total_frames: int
    # Per-frame results: dicts with "frame_path" (URL) and "nsfw_detected".
    frames: list
def _download_video(url, dest_path, timeout=60):
    """Stream ``url`` into ``dest_path`` in 8 KiB chunks; raise on HTTP errors or timeouts."""
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)


# API Endpoints
@app.post("/process_video", response_model=VideoResponse)
async def process_video_endpoint(request: VideoRequest):
    """Download a video, sample frames, and report how many look NSFW.

    The download and the per-frame classification are blocking (requests,
    OpenCV, remote HTTP calls). They run in FastAPI's threadpool, so one slow
    request does not stall the event loop and every other request with it.
    The downloaded file is always removed before returning.

    Raises:
        HTTPException: 400 if the video cannot be downloaded.
    """
    from fastapi.concurrency import run_in_threadpool

    video_path = os.path.join("/tmp", f"{uuid.uuid4().hex}.mp4")
    try:
        # NOTE(review): video_url is caller-controlled and fetched server-side
        # (SSRF risk); restrict allowed hosts if this API is publicly exposed.
        await run_in_threadpool(_download_video, request.video_url, video_path)
    except Exception as e:
        # Don't leave a partially written file behind on a failed download.
        if os.path.exists(video_path):
            os.remove(video_path)
        raise HTTPException(status_code=400, detail=f"Error downloading video: {e}") from e
    try:
        result = await run_in_threadpool(process_video, video_path, request.num_frames)
    finally:
        # Cleanup video file
        if os.path.exists(video_path):
            os.remove(video_path)
    return result