yaya36095 commited on
Commit
dd32056
·
verified ·
1 Parent(s): ab2108c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +60 -21
handler.py CHANGED
@@ -1,48 +1,87 @@
1
- import base64, io, os
 
 
 
 
2
  import torch
3
- from torchvision import transforms
4
  from PIL import Image
5
  from safetensors.torch import load_file
6
  from timm import create_model
 
7
 
8
 
9
  class EndpointHandler:
10
- """Custom pipeline for HF Inference Endpoints."""
11
 
12
- def __init__(self, model_dir: str):
 
 
 
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
- weights = load_file(os.path.join(model_dir, "model.safetensors"))
 
 
 
 
16
  self.model = create_model("vit_base_patch16_224", num_classes=5)
17
- self.model.load_state_dict(weights)
18
  self.model.eval().to(self.device)
19
 
20
- self.transform = transforms.Compose([
21
- transforms.Resize((224, 224), interpolation=Image.BICUBIC),
22
- transforms.ToTensor(),
23
- ])
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- self.labels = ["stable_diffusion", "midjourney", "dalle", "real", "other_ai"]
 
 
 
 
 
26
 
27
- def _prep(self, img: Image.Image):
28
- return self.transform(img.convert("RGB")).unsqueeze(0).to(self.device)
 
29
 
30
- def __call__(self, data):
31
- img = None
 
 
 
 
 
 
 
 
 
32
  if isinstance(data, Image.Image):
33
  img = data
34
  elif isinstance(data, dict):
35
- b64 = data.get("inputs") or data.get("image")
36
- if isinstance(b64, (str, bytes)):
37
- if isinstance(b64, str):
38
- b64 = b64.encode()
39
- img = Image.open(io.BytesIO(base64.b64decode(b64)))
40
 
41
  if img is None:
42
  return {"error": "No image provided"}
43
 
 
44
  with torch.no_grad():
45
- logits = self.model(self._prep(img))
46
  probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
47
 
48
  return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+ from typing import Dict, Any
5
+
6
  import torch
 
7
  from PIL import Image
8
  from safetensors.torch import load_file
9
  from timm import create_model
10
+ from torchvision import transforms
11
 
12
 
13
  class EndpointHandler:
14
+ """Custom image-classification pipeline for Hugging Face Inference Endpoints."""
15
 
16
+ # --------------------------------------------------
17
+ # 1) تحميل النموذج والوزن مرة واحدة عند تشغيل الـ Endpoint
18
+ # --------------------------------------------------
19
+ def __init__(self, model_dir: str) -> None:
20
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # وزن محفوظ بصيغة safetensors
23
+ weights_path = os.path.join(model_dir, "model.safetensors")
24
+ state_dict = load_file(weights_path)
25
+
26
+ # أنشئ نفس معماريّة ViT التى درّبتها (num_classes = 5)
27
  self.model = create_model("vit_base_patch16_224", num_classes=5)
28
+ self.model.load_state_dict(state_dict)
29
  self.model.eval().to(self.device)
30
 
31
+ # محوّلات التحضير
32
+ self.preprocess = transforms.Compose(
33
+ [
34
+ transforms.Resize((224, 224), interpolation=Image.BICUBIC),
35
+ transforms.ToTensor(),
36
+ ]
37
+ )
38
+
39
+ self.labels = [
40
+ "stable_diffusion",
41
+ "midjourney",
42
+ "dalle",
43
+ "real",
44
+ "other_ai",
45
+ ]
46
 
47
+ # --------------------------------------------------
48
+ # 2) دوال مساعدة
49
+ # --------------------------------------------------
50
+ def _image_from_bytes(self, b: bytes) -> Image.Image:
51
+ """decode base64 → PIL"""
52
+ return Image.open(io.BytesIO(base64.b64decode(b)))
53
 
54
+ def _to_tensor(self, img: Image.Image) -> torch.Tensor:
55
+ """PIL → tensor (1 × 3 × 224 × 224) على نفس الجهاز"""
56
+ return self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
57
 
58
+ # --------------------------------------------------
59
+ # 3) الدالة الرئيسة التى تستدعيها المنصّة لكل طلب
60
+ # --------------------------------------------------
61
+ def __call__(self, data: Any) -> Dict[str, float]:
62
+ """
63
+ يدعم:
64
+ • Widget — يمرّر PIL.Image مباشرةً
65
+ • REST — يمرّر dict وفيه مفتاح "inputs" أو "image" (base64)
66
+ """
67
+ # — الحصول على صورة PIL —
68
+ img: Image.Image | None = None
69
  if isinstance(data, Image.Image):
70
  img = data
71
  elif isinstance(data, dict):
72
+ payload = data.get("inputs") or data.get("image")
73
+ if isinstance(payload, (str, bytes)):
74
+ if isinstance(payload, str):
75
+ payload = payload.encode()
76
+ img = self._image_from_bytes(payload)
77
 
78
  if img is None:
79
  return {"error": "No image provided"}
80
 
81
+ # — الاستدلال —
82
  with torch.no_grad():
83
+ logits = self.model(self._to_tensor(img))
84
  probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
85
 
86
  return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}
87
+ fix: proper indentation for handler