import base64
import io
import os
from typing import Dict, Any

import torch
from PIL import Image
from safetensors.torch import load_file
from timm import create_model
from torchvision import transforms


class EndpointHandler:
    """Custom image-classification pipeline for Hugging Face Inference Endpoints."""

    # --------------------------------------------------
    # 1) Load the model and weights once when the Endpoint starts
    # --------------------------------------------------
    def __init__(self, model_dir: str) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Weights saved in safetensors format
        weights_path = os.path.join(model_dir, "model.safetensors")
        state_dict = load_file(weights_path)

        # Build the same ViT architecture that was trained (num_classes = 5)
        self.model = create_model("vit_base_patch16_224", num_classes=5)
        self.model.load_state_dict(state_dict)
        self.model.eval().to(self.device)

        # Preprocessing transforms
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(
                    (224, 224),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
            ]
        )

        self.labels = [
            "stable_diffusion",
            "midjourney",
            "dalle",
            "real",
            "other_ai",
        ]

    # --------------------------------------------------
    # 2) Helper functions
    # --------------------------------------------------
    def _image_from_bytes(self, b: bytes) -> Image.Image:
        """Decode base64 bytes → PIL image."""
        return Image.open(io.BytesIO(base64.b64decode(b)))

    def _to_tensor(self, img: Image.Image) -> torch.Tensor:
        """PIL → tensor (1 × 3 × 224 × 224) on the model's device."""
        return self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)

    # --------------------------------------------------
    # 3) Main entry point the platform calls for each request
    # --------------------------------------------------
    def __call__(self, data: Any) -> Dict[str, float]:
        """
        Supports:
          • Widget — passes a PIL.Image directly
          • REST   — passes a dict with an "inputs" or "image" key (base64)
        """
        # — Obtain a PIL image —
        img: Image.Image | None = None

        if isinstance(data, Image.Image):
            img = data
        elif isinstance(data, dict):
            payload = data.get("inputs") or data.get("image")
            if isinstance(payload, (str, bytes)):
                if isinstance(payload, str):
                    payload = payload.encode()
                img = self._image_from_bytes(payload)

        if img is None:
            return {"error": "No image provided"}

        # — Inference —
        with torch.no_grad():
            logits = self.model(self._to_tensor(img))
            probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)

        return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}
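

# --------------------------------------------------
# Optional local smoke test (not invoked by the Endpoint runtime).
# A minimal sketch assuming "model.safetensors" is available in the
# current directory and "sample.jpg" is a hypothetical test image;
# it mimics the REST payload shape ({"inputs": <base64 string>}).
# --------------------------------------------------
if __name__ == "__main__":
    import json

    handler = EndpointHandler(model_dir=".")

    # Encode the test image exactly as a REST client would
    with open("sample.jpg", "rb") as f:
        payload = base64.b64encode(f.read()).decode()

    scores = handler({"inputs": payload})
    print(json.dumps(scores, indent=2))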