ai-source-detector / handler.py
yaya36095's picture
Update handler.py
b10f28b verified
raw
history blame
3.25 kB
import base64
import io
import os
from typing import Dict, Any
import torch
from PIL import Image
from safetensors.torch import load_file
from timm import create_model
from torchvision import transforms
class EndpointHandler:
"""Custom image-classification pipeline for Hugging Face Inference Endpoints."""
# --------------------------------------------------
# 1) تحميل النموذج والوزن مرة واحدة عند تشغيل الـ Endpoint
# --------------------------------------------------
def __init__(self, model_dir: str) -> None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# وزن محفوظ بصيغة safetensors
weights_path = os.path.join(model_dir, "model.safetensors")
state_dict = load_file(weights_path)
# أنشئ نفس معماريّة ViT التى درّبتها (num_classes = 5)
self.model = create_model("vit_base_patch16_224", num_classes=5)
self.model.load_state_dict(state_dict)
self.model.eval().to(self.device)
# محوّلات التحضير
self.preprocess = transforms.Compose(
[
transforms.Resize((224, 224), interpolation=Image.BICUBIC),
transforms.ToTensor(),
]
)
self.labels = [
"stable_diffusion",
"midjourney",
"dalle",
"real",
"other_ai",
]
# --------------------------------------------------
# 2) دوال مساعدة
# --------------------------------------------------
def _image_from_bytes(self, b: bytes) -> Image.Image:
"""decode base64 → PIL"""
return Image.open(io.BytesIO(base64.b64decode(b)))
def _to_tensor(self, img: Image.Image) -> torch.Tensor:
"""PIL → tensor (1 × 3 × 224 × 224) على نفس الجهاز"""
return self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
# --------------------------------------------------
# 3) الدالة الرئيسة التى تستدعيها المنصّة لكل طلب
# --------------------------------------------------
def __call__(self, data: Any) -> Dict[str, float]:
"""
يدعم:
• Widget — يمرّر PIL.Image مباشرةً
• REST — يمرّر dict وفيه مفتاح "inputs" أو "image" (base64)
"""
# — الحصول على صورة PIL —
img: Image.Image | None = None
if isinstance(data, Image.Image):
img = data
elif isinstance(data, dict):
payload = data.get("inputs") or data.get("image")
if isinstance(payload, (str, bytes)):
if isinstance(payload, str):
payload = payload.encode()
img = self._image_from_bytes(payload)
if img is None:
return {"error": "No image provided"}
# — الاستدلال —
with torch.no_grad():
logits = self.model(self._to_tensor(img))
probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}