|
import base64 |
|
import io |
|
import os |
|
from typing import Dict, Any |
|
|
|
import torch |
|
from PIL import Image |
|
from safetensors.torch import load_file |
|
from timm import create_model |
|
from torchvision import transforms |
|
|
|
|
|
class EndpointHandler:
    """Custom image-classification pipeline for Hugging Face Inference Endpoints.

    Builds a timm ``vit_base_patch16_224`` classifier, loads its weights from
    ``model.safetensors`` and, when called, returns one probability per entry
    of ``self.labels`` for a single input image.
    """

    def __init__(self, model_dir: str) -> None:
        """Load the model weights and set up preprocessing.

        Args:
            model_dir: Directory that contains ``model.safetensors``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Position i in this list names output i of the classifier head.
        # It is defined before the model so the head size is derived from it
        # instead of being a second hard-coded number.
        self.labels = [
            "stable_diffusion",
            "midjourney",
            "dalle",
            "real",
            "other_ai",
        ]

        weights_path = os.path.join(model_dir, "model.safetensors")
        state_dict = load_file(weights_path)

        self.model = create_model(
            "vit_base_patch16_224", num_classes=len(self.labels)
        )
        self.model.load_state_dict(state_dict)
        self.model.eval().to(self.device)

        # NOTE(review): no mean/std normalisation is applied after ToTensor();
        # confirm this matches how the checkpoint was trained.
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(
                    (224, 224),
                    # torchvision's own enum rather than the Pillow integer
                    # constant; both select bicubic resampling.
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
            ]
        )

    @staticmethod
    def _strip_data_uri(b: bytes) -> bytes:
        """Return *b* without a leading ``data:<mime>;base64,`` header.

        Input that does not start with ``data:`` (or has no comma after it) is
        returned unchanged.
        """
        # Slice comparison instead of startswith() so that a non-bytes
        # argument simply falls through instead of raising TypeError.
        if b[:5] == b"data:":
            _header, sep, body = b.partition(b",")
            if sep:
                return body
        return b

    def _image_from_bytes(self, b: bytes) -> "Image.Image":
        """Decode base64 data (optionally wrapped in a data URI) into a PIL image.

        Args:
            b: Base64-encoded image data. A ``str`` is accepted as well and is
                encoded to bytes first.

        Returns:
            The image opened by ``PIL.Image.open``.

        Raises:
            Whatever ``base64.b64decode`` or ``PIL.Image.open`` raise for
            undecodable input; nothing is caught here.
        """
        if isinstance(b, str):
            b = b.encode()
        # Without stripping, the non-validating decoder would treat the
        # letters of the "data:...;base64," header as payload and corrupt
        # the image bytes.
        raw = base64.b64decode(self._strip_data_uri(b))
        return Image.open(io.BytesIO(raw))

    def _to_tensor(self, img: "Image.Image") -> torch.Tensor:
        """Convert a PIL image to a 1 x 3 x 224 x 224 tensor on ``self.device``."""
        rgb = img.convert("RGB")
        return self.preprocess(rgb).unsqueeze(0).to(self.device)

    def _extract_image(self, data: Any) -> "Image.Image | None":
        """Pull a PIL image out of a request payload, or return ``None``.

        Accepted shapes:
            * a ``PIL.Image.Image`` passed directly;
            * a dict whose ``"inputs"`` (or, failing that, ``"image"``) value is
              either a ``PIL.Image.Image`` or base64 ``str``/``bytes``.
        """
        if isinstance(data, Image.Image):
            return data
        if not isinstance(data, dict):
            return None

        payload = data.get("inputs") or data.get("image")
        if isinstance(payload, Image.Image):
            # Image already decoded by the caller; use it as is.
            return payload
        if isinstance(payload, str):
            payload = payload.encode()
        if isinstance(payload, bytes):
            return self._image_from_bytes(payload)
        return None

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Classify one image.

        Supports:
            * Widget - passes a ``PIL.Image`` directly.
            * REST - passes a dict with an ``"inputs"`` or ``"image"`` key holding
              a base64 string/bytes (a ``data:`` URI header is tolerated) or an
              already-decoded ``PIL.Image``.

        Args:
            data: The request payload in one of the shapes above.

        Returns:
            A mapping from each label in ``self.labels`` to its softmax
            probability, or ``{"error": "No image provided"}`` when no image
            could be found in *data*.
        """
        img = self._extract_image(data)
        if img is None:
            return {"error": "No image provided"}

        with torch.no_grad():
            logits = self.model(self._to_tensor(img))
            probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)

        # tolist() moves all scores to Python floats in a single transfer.
        return dict(zip(self.labels, probs.tolist()))
|
|