yaya36095 commited on
Commit
bf94d8e
·
verified ·
1 Parent(s): 26b92af

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +8 -30
handler.py CHANGED
@@ -3,52 +3,31 @@ import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
  from safetensors.torch import load_file
6
- from timm import create_model # timm ضروري لتشغيل ViT
7
 
8
 
9
  class EndpointHandler:
10
- """Custom pipeline for Hugging Face Inference Endpoints."""
11
 
12
  def __init__(self, model_dir: str):
13
- # اختَر GPU إذا متاح
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- # تحميل الوزن بصيغة safetensors
17
- weights_path = os.path.join(model_dir, "model.safetensors")
18
- weights = load_file(weights_path)
19
-
20
- # إنشاء نموذج ViT مطابق لِما درّبتَه
21
  self.model = create_model("vit_base_patch16_224", num_classes=5)
22
  self.model.load_state_dict(weights)
23
  self.model.eval().to(self.device)
24
 
25
- # تحويـلات الصورة
26
- self.transform = transforms.Compose(
27
- [
28
- transforms.Resize((224, 224), interpolation=Image.BICUBIC),
29
- transforms.ToTensor(),
30
- ]
31
- )
32
 
33
- self.labels = [
34
- "stable_diffusion",
35
- "midjourney",
36
- "dalle",
37
- "real",
38
- "other_ai",
39
- ]
40
 
41
- # ---------- helpers ----------
42
  def _prep(self, img: Image.Image):
43
  return self.transform(img.convert("RGB")).unsqueeze(0).to(self.device)
44
 
45
- # ---------- main entry ----------
46
  def __call__(self, data):
47
- """
48
- يدعم:
49
- • Widget: يستلم PIL.Image
50
- • REST API: يستلم base64 فى data["inputs"] أو data["image"]
51
- """
52
  img = None
53
  if isinstance(data, Image.Image):
54
  img = data
@@ -67,4 +46,3 @@ class EndpointHandler:
67
  probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
68
 
69
  return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}
70
- fix: correct handler indentation
 
3
  from torchvision import transforms
4
  from PIL import Image
5
  from safetensors.torch import load_file
6
+ from timm import create_model
7
 
8
 
9
  class EndpointHandler:
10
+ """Custom pipeline for HF Inference Endpoints."""
11
 
12
  def __init__(self, model_dir: str):
 
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ weights = load_file(os.path.join(model_dir, "model.safetensors"))
 
 
 
 
16
  self.model = create_model("vit_base_patch16_224", num_classes=5)
17
  self.model.load_state_dict(weights)
18
  self.model.eval().to(self.device)
19
 
20
+ self.transform = transforms.Compose([
21
+ transforms.Resize((224, 224), interpolation=Image.BICUBIC),
22
+ transforms.ToTensor(),
23
+ ])
 
 
 
24
 
25
+ self.labels = ["stable_diffusion", "midjourney", "dalle", "real", "other_ai"]
 
 
 
 
 
 
26
 
 
27
  def _prep(self, img: Image.Image):
28
  return self.transform(img.convert("RGB")).unsqueeze(0).to(self.device)
29
 
 
30
  def __call__(self, data):
 
 
 
 
 
31
  img = None
32
  if isinstance(data, Image.Image):
33
  img = data
 
46
  probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
47
 
48
  return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}