yaya36095 commited on
Commit
963149c
·
verified ·
1 Parent(s): bc35f37

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +50 -98
handler.py CHANGED
@@ -1,21 +1,13 @@
1
  import base64
2
  import io
3
  import os
4
- import sys
5
  from typing import Dict, Any, List
6
 
7
  import torch
8
  from PIL import Image
9
- from timm import create_model
10
- from torchvision import transforms
11
- from safetensors.torch import load_file
12
 
13
  class EndpointHandler:
14
- """Custom ViT image-classifier for Hugging Face Inference Endpoints."""
15
-
16
- # --------------------------------------------------
17
- # 1) تحميل النموذج والوزن مرة واحدة
18
- # --------------------------------------------------
19
  def __init__(self, model_dir: str) -> None:
20
  print(f"بدء تهيئة النموذج من المسار: {model_dir}")
21
  print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
@@ -23,70 +15,53 @@ class EndpointHandler:
23
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  print(f"استخدام الجهاز: {self.device}")
25
 
26
- # تحديد مسارات الملفات المحتملة
27
- pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
28
- safetensors_path = os.path.join(model_dir, "model.safetensors")
29
-
30
- print(f"مسار ملف PyTorch: {pytorch_path}, موجود: {os.path.exists(pytorch_path)}")
31
- print(f"مسار ملف Safetensors: {safetensors_path}, موجود: {os.path.exists(safetensors_path)}")
32
-
33
- # إنشاء ViT Base Patch-16 بعدد فئات 5
34
  try:
35
- print("محاولة إنشاء نموذج ViT")
36
- self.model = create_model("vit_base_patch16_224", num_classes=5)
37
- print("تم إنشاء نموذج ViT بنجاح")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
- print(f"خطأ في إنشاء النموذج: {e}")
 
 
40
  raise
41
 
42
- # محاولة تحميل النموذج من pytorch_model.bin أولاً
43
- model_loaded = False
44
- if os.path.exists(pytorch_path):
45
- try:
46
- print(f"محاولة تحميل النموذج من: {pytorch_path}")
47
- state_dict = torch.load(pytorch_path, map_location="cpu")
48
- print(f"مفاتيح state_dict: {list(state_dict.keys())[:5]}...")
49
-
50
- # طباعة بنية النموذج ومفاتيح state_dict للمقارنة
51
- model_keys = set(k for k, _ in self.model.named_parameters())
52
- state_dict_keys = set(state_dict.keys())
53
- print(f"عدد مفاتيح النموذج: {len(model_keys)}")
54
- print(f"عدد مفاتيح state_dict: {len(state_dict_keys)}")
55
- print(f"المفاتيح المشتركة: {len(model_keys.intersection(state_dict_keys))}")
56
-
57
- self.model.load_state_dict(state_dict)
58
- print("تم تحميل النموذج بنجاح من pytorch_model.bin")
59
- model_loaded = True
60
- except Exception as e:
61
- print(f"خطأ في تحميل pytorch_model.bin: {e}")
62
- print(f"نوع الخطأ: {type(e).__name__}")
63
- print(f"تفاصيل الخطأ: {str(e)}")
64
-
65
- # إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
66
- if not model_loaded and os.path.exists(safetensors_path):
67
- try:
68
- print(f"محاولة تحميل النموذج من: {safetensors_path}")
69
- # تحميل النموذج بدون محاولة مطابقة الهيكل مباشرة
70
- # سنقوم بتهيئة النموذج من الصفر بدلاً من ذلك
71
- print("تهيئة نموذج ViT من الصفر")
72
- # لا نحاول تحميل safetensors لأنه يحتوي على هيكل مختلف
73
- print("تم تهيئة نموذج ViT بدون أوزان مسبقة")
74
- model_loaded = True
75
- except Exception as e:
76
- print(f"خطأ في تحميل model.safetensors: {e}")
77
-
78
- if not model_loaded:
79
- print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
80
-
81
- self.model.eval().to(self.device)
82
- print("تم تحويل النموذج إلى وضع التقييم")
83
-
84
- # محوّلات التحضير
85
- self.preprocess = transforms.Compose([
86
- transforms.Resize((224, 224), interpolation=Image.BICUBIC),
87
- transforms.ToTensor(),
88
- ])
89
-
90
  self.labels = [
91
  "stable_diffusion",
92
  "midjourney",
@@ -94,20 +69,6 @@ class EndpointHandler:
94
  "real",
95
  "other_ai",
96
  ]
97
- print(f"تم تعريف التسميات: {self.labels}")
98
-
99
- # --------------------------------------------------
100
- # 2) دوال مساعدة
101
- # --------------------------------------------------
102
- def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
103
- try:
104
- print(f"تحويل الصورة إلى تنسور. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
105
- tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
106
- print(f"تم تحويل الصورة بنجاح. شكل التنسور: {tensor.shape}")
107
- return tensor
108
- except Exception as e:
109
- print(f"خطأ في تحويل الصورة إلى تنسور: {e}")
110
- raise
111
 
112
  def _decode_b64(self, b: bytes) -> Image.Image:
113
  try:
@@ -119,18 +80,7 @@ class EndpointHandler:
119
  print(f"خطأ في فك ترميز base64: {e}")
120
  raise
121
 
122
- # --------------------------------------------------
123
- # 3) الدالة الرئيسة
124
- # --------------------------------------------------
125
  def __call__(self, data: Any) -> List[Dict[str, Any]]:
126
- """
127
- يدعم:
128
- • Widget (PIL.Image)
129
- • REST (base64 فى data["inputs"] أو data["image"])
130
-
131
- يعيد:
132
- • مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...]
133
- """
134
  print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
135
 
136
  img: Image.Image | None = None
@@ -154,11 +104,14 @@ class EndpointHandler:
154
  print("لم يتم العثور على صورة صالحة في البيانات")
155
  return [{"label": "error", "score": 1.0}]
156
 
 
 
 
157
  print("بدء التنبؤ باستخدام النموذج")
158
  with torch.no_grad():
159
- tensor = self._img_to_tensor(img)
160
- logits = self.model(tensor)
161
- probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
162
  print(f"تم الحصول على الاحتمالات: {probs}")
163
 
164
  # تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
@@ -181,7 +134,6 @@ class EndpointHandler:
181
  print(f"حدث خطأ أثناء المعالجة: {e}")
182
  print(f"نوع الخطأ: {type(e).__name__}")
183
  print(f"تفاصيل الخطأ: {str(e)}")
184
- # تتبع الاستثناء الكامل
185
  import traceback
186
  traceback.print_exc()
187
 
 
1
  import base64
2
  import io
3
  import os
 
4
  from typing import Dict, Any, List
5
 
6
  import torch
7
  from PIL import Image
8
+ from transformers import ViTImageProcessor, ViTForImageClassification
 
 
9
 
10
  class EndpointHandler:
 
 
 
 
 
11
  def __init__(self, model_dir: str) -> None:
12
  print(f"بدء تهيئة النموذج من المسار: {model_dir}")
13
  print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
 
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  print(f"استخدام الجهاز: {self.device}")
17
 
18
+ # استخدام مكتبة transformers بدلاً من timm
 
 
 
 
 
 
 
19
  try:
20
+ print("تحميل معالج الصور ViT")
21
+ self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
22
+
23
+ print("تحميل نموذج ViT")
24
+ self.model = ViTForImageClassification.from_pretrained(
25
+ "google/vit-base-patch16-224",
26
+ num_labels=5,
27
+ id2label={
28
+ 0: "stable_diffusion",
29
+ 1: "midjourney",
30
+ 2: "dalle",
31
+ 3: "real",
32
+ 4: "other_ai"
33
+ },
34
+ label2id={
35
+ "stable_diffusion": 0,
36
+ "midjourney": 1,
37
+ "dalle": 2,
38
+ "real": 3,
39
+ "other_ai": 4
40
+ }
41
+ )
42
+
43
+ # محاولة تحميل الأوزان المخصصة إذا كانت موجودة
44
+ pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
45
+ if os.path.exists(pytorch_path):
46
+ print(f"محاولة تحميل الأوزان المخصصة من: {pytorch_path}")
47
+ try:
48
+ state_dict = torch.load(pytorch_path, map_location="cpu")
49
+ # تحميل الأوزان المتوافقة فقط
50
+ self.model.load_state_dict(state_dict, strict=False)
51
+ print("تم تحميل الأوزان المخصصة بنجاح")
52
+ except Exception as e:
53
+ print(f"تحذير: فشل تحميل الأوزان المخصصة: {e}")
54
+
55
+ self.model.to(self.device)
56
+ self.model.eval()
57
+ print("تم تهيئة النموذج بنجاح")
58
+
59
  except Exception as e:
60
+ print(f"خطأ في تهيئة النموذج: {e}")
61
+ import traceback
62
+ traceback.print_exc()
63
  raise
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  self.labels = [
66
  "stable_diffusion",
67
  "midjourney",
 
69
  "real",
70
  "other_ai",
71
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def _decode_b64(self, b: bytes) -> Image.Image:
74
  try:
 
80
  print(f"خطأ في فك ترميز base64: {e}")
81
  raise
82
 
 
 
 
83
  def __call__(self, data: Any) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
84
  print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
85
 
86
  img: Image.Image | None = None
 
104
  print("لم يتم العثور على صورة صالحة في البيانات")
105
  return [{"label": "error", "score": 1.0}]
106
 
107
+ print("معالجة الصورة باستخدام معالج ViT")
108
+ inputs = self.processor(images=img, return_tensors="pt").to(self.device)
109
+
110
  print("بدء التنبؤ باستخدام النموذج")
111
  with torch.no_grad():
112
+ outputs = self.model(**inputs)
113
+ logits = outputs.logits
114
+ probs = torch.nn.functional.softmax(logits[0], dim=0)
115
  print(f"تم الحصول على الاحتمالات: {probs}")
116
 
117
  # تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
 
134
  print(f"حدث خطأ أثناء المعالجة: {e}")
135
  print(f"نوع الخطأ: {type(e).__name__}")
136
  print(f"تفاصيل الخطأ: {str(e)}")
 
137
  import traceback
138
  traceback.print_exc()
139