File size: 3,249 Bytes
dd32056
 
 
 
 
00fa5d2
 
 
bf94d8e
dd32056
00fa5d2
 
8ea56b1
dd32056
8ea56b1
dd32056
 
 
 
00fa5d2
 
dd32056
 
 
 
 
00fa5d2
dd32056
00fa5d2
 
dd32056
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00fa5d2
dd32056
 
 
 
 
 
00fa5d2
dd32056
 
 
00fa5d2
dd32056
 
 
 
 
 
 
 
 
 
 
00fa5d2
 
 
dd32056
 
 
 
 
00fa5d2
 
 
 
dd32056
00fa5d2
dd32056
00fa5d2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import base64
import io
import os
from typing import Dict, Any

import torch
from PIL import Image
from safetensors.torch import load_file
from timm import create_model
from torchvision import transforms


class EndpointHandler:
    """Custom image-classification pipeline for Hugging Face Inference Endpoints.

    Loads a ViT checkpoint once at startup and classifies an input image into
    five generator categories: stable_diffusion / midjourney / dalle / real /
    other_ai.
    """

    # --------------------------------------------------
    # 1) Load the model and weights once when the Endpoint starts
    # --------------------------------------------------
    def __init__(self, model_dir: str) -> None:
        """Load weights from *model_dir* and build the preprocessing pipeline.

        Args:
            model_dir: Directory containing ``model.safetensors``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Weights saved in safetensors format
        weights_path = os.path.join(model_dir, "model.safetensors")
        state_dict = load_file(weights_path)

        # Re-create the same ViT architecture used during training (num_classes = 5)
        self.model = create_model("vit_base_patch16_224", num_classes=5)
        self.model.load_state_dict(state_dict)
        self.model.eval().to(self.device)

        # Preprocessing transforms.
        # NOTE(review): there is no Normalize() step — confirm the model was
        # trained on raw [0, 1] tensors; otherwise add the training-time
        # normalization constants here.
        self.preprocess = transforms.Compose(
            [
                # Use torchvision's InterpolationMode rather than the raw PIL
                # constant, which torchvision deprecates for Resize.
                transforms.Resize(
                    (224, 224),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
            ]
        )

        # Logit index i corresponds to self.labels[i].
        self.labels = [
            "stable_diffusion",
            "midjourney",
            "dalle",
            "real",
            "other_ai",
        ]

    # --------------------------------------------------
    # 2) Helper functions
    # --------------------------------------------------
    def _image_from_bytes(self, b: bytes) -> Image.Image:
        """Decode a base64-encoded payload into a PIL image.

        Raises:
            ValueError: if *b* is not valid base64 (``binascii.Error`` is a
                ``ValueError`` subclass).
            OSError: if the decoded bytes are not a readable image.
        """
        return Image.open(io.BytesIO(base64.b64decode(b)))

    def _to_tensor(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a (1, 3, 224, 224) float tensor on self.device."""
        return self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)

    # --------------------------------------------------
    # 3) Main entry point called by the platform for each request
    # --------------------------------------------------
    def __call__(self, data: Any) -> Dict[str, Any]:
        """Classify one image and return per-label probabilities.

        Supports:
            * Widget — passes a ``PIL.Image`` directly.
            * REST   — passes a dict with an ``"inputs"`` or ``"image"`` key
              holding a base64-encoded image (str or bytes).

        Returns:
            ``{label: probability}`` on success, or ``{"error": message}``
            when no usable image is provided.
        """
        # -- obtain a PIL image --
        img: Image.Image | None = None
        if isinstance(data, Image.Image):
            img = data
        elif isinstance(data, dict):
            payload = data.get("inputs") or data.get("image")
            if isinstance(payload, (str, bytes)):
                if isinstance(payload, str):
                    payload = payload.encode()
                try:
                    img = self._image_from_bytes(payload)
                except (ValueError, OSError):
                    # Invalid base64 or undecodable image bytes: report the
                    # failure in the handler's error-dict shape instead of
                    # letting the exception escape as a 500.
                    return {"error": "Invalid image payload"}

        if img is None:
            return {"error": "No image provided"}

        # -- inference --
        with torch.no_grad():
            logits = self.model(self._to_tensor(img))
            probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)

        return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}