yhay360 committed
Commit 00fa5d2 · 1 Parent(s): c1dc91f

feat: add EndpointHandler

Files changed (1)
1. handler.py +47 -0
handler.py ADDED
@@ -0,0 +1,47 @@
+ import base64, io, os
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ from safetensors.torch import load_file
+ from timm import create_model  # timm is needed to build the ViT architecture
+
+
+ class EndpointHandler:  # the class name matters: Inference Endpoints looks for "EndpointHandler"
+     def __init__(self, model_dir: str):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # load the weights stored in safetensors format
+         weights = load_file(os.path.join(model_dir, "model.safetensors"))
+         self.model = create_model("vit_base_patch16_224", num_classes=5)
+         self.model.load_state_dict(weights)
+         self.model.eval().to(self.device)
+
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+         ])
+
+         self.labels = ['stable_diffusion', 'midjourney', 'dalle', 'real', 'other_ai']
+
+     def _prep(self, img: Image.Image):
+         return self.transform(img.convert("RGB")).unsqueeze(0).to(self.device)
+
+     def __call__(self, data):
+         # supports both the widget (a PIL image) and REST requests (base64-encoded image)
+         img = None
+         if isinstance(data, Image.Image):
+             img = data
+         elif isinstance(data, dict):
+             b = data.get("inputs") or data.get("image")
+             if isinstance(b, (str, bytes)):
+                 b = b.encode() if isinstance(b, str) else b
+                 img = Image.open(io.BytesIO(base64.b64decode(b)))
+
+         if img is None:
+             return {"error": "No image provided"}
+
+         with torch.no_grad():
+             logits = self.model(self._prep(img))
+             probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
+
+         return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}
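
For reference, a minimal local smoke test of the new handler could look like the sketch below. It is only an illustration under stated assumptions: "./model" stands in for a local directory containing model.safetensors, and "example.jpg" for any test image. It exercises both input paths that __call__ accepts: the direct PIL image (widget path) and the {"inputs": <base64 string>} payload used over REST.

# smoke_test.py -- hypothetical local test; "./model" and "example.jpg" are placeholders
import base64
import io

from PIL import Image

from handler import EndpointHandler

handler = EndpointHandler(model_dir="./model")

# Path 1: pass a PIL image directly (the widget path).
image = Image.open("example.jpg")
print(handler(image))

# Path 2: send a base64-encoded payload, matching the REST request shape.
buf = io.BytesIO()
image.save(buf, format="JPEG")
payload = {"inputs": base64.b64encode(buf.getvalue()).decode("utf-8")}
print(handler(payload))  # dict mapping each label to its softmax probability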