File size: 3,249 Bytes
dd32056 00fa5d2 bf94d8e dd32056 00fa5d2 8ea56b1 dd32056 8ea56b1 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 dd32056 00fa5d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import base64
import io
import os
from typing import Dict, Any
import torch
from PIL import Image
from safetensors.torch import load_file
from timm import create_model
from torchvision import transforms
class EndpointHandler:
    """Custom image-classification pipeline for Hugging Face Inference Endpoints.

    Loads a ViT classifier (timm ``vit_base_patch16_224``, 5 classes) from a
    ``model.safetensors`` checkpoint in *model_dir* and maps input images to
    probabilities over: stable_diffusion, midjourney, dalle, real, other_ai.
    """

    # --------------------------------------------------
    # 1) Load the model and weights once when the Endpoint starts
    # --------------------------------------------------
    def __init__(self, model_dir: str) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Weights saved in safetensors format
        weights_path = os.path.join(model_dir, "model.safetensors")
        state_dict = load_file(weights_path)

        # Recreate the exact ViT architecture used during training (num_classes = 5)
        self.model = create_model("vit_base_patch16_224", num_classes=5)
        self.model.load_state_dict(state_dict)
        self.model.eval().to(self.device)

        # Preprocessing transforms.
        # InterpolationMode.BICUBIC replaces the deprecated PIL module-level
        # ``Image.BICUBIC`` constant (removed in Pillow 10) — same resampling.
        # NOTE(review): there is no Normalize() step here — confirm the model
        # was trained on raw [0, 1] tensors; otherwise add the training-time
        # normalization to match.
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(
                    (224, 224),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
            ]
        )

        # Class index -> label name; order must match the training label order.
        self.labels = [
            "stable_diffusion",
            "midjourney",
            "dalle",
            "real",
            "other_ai",
        ]

    # --------------------------------------------------
    # 2) Helper functions
    # --------------------------------------------------
    def _image_from_bytes(self, b: bytes) -> Image.Image:
        """Decode base64-encoded bytes into a PIL image."""
        return Image.open(io.BytesIO(base64.b64decode(b)))

    def _to_tensor(self, img: Image.Image) -> torch.Tensor:
        """PIL image -> tensor (1 x 3 x 224 x 224) on the model's device."""
        return self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)

    # --------------------------------------------------
    # 3) Main entry point called by the platform for each request
    # --------------------------------------------------
    def __call__(self, data: Any) -> Dict[str, float]:
        """Classify one image.

        Supports:
          - Widget — passes a PIL.Image directly
          - REST   — passes a dict with an "inputs" or "image" key holding
                     either a base64 string/bytes or an already-decoded
                     PIL.Image

        Returns a mapping ``label -> probability``, or
        ``{"error": "No image provided"}`` when no image could be extracted.
        """
        # -- Obtain a PIL image --
        img: Image.Image | None = None
        if isinstance(data, Image.Image):
            img = data
        elif isinstance(data, dict):
            payload = data.get("inputs") or data.get("image")
            if isinstance(payload, Image.Image):
                # Some callers decode the image before it reaches the handler.
                img = payload
            elif isinstance(payload, (str, bytes)):
                if isinstance(payload, str):
                    payload = payload.encode()
                img = self._image_from_bytes(payload)

        if img is None:
            return {"error": "No image provided"}

        # -- Inference --
        with torch.no_grad():
            logits = self.model(self._to_tensor(img))
            probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)

        return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}
|