Update handler.py
Browse files- handler.py +96 -30
handler.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import base64
|
2 |
import io
|
3 |
import os
|
|
|
4 |
from typing import Dict, Any, List
|
5 |
|
6 |
import torch
|
@@ -16,14 +17,27 @@ class EndpointHandler:
|
|
16 |
# 1) تحميل النموذج والوزن مرة واحدة
|
17 |
# --------------------------------------------------
|
18 |
def __init__(self, model_dir: str) -> None:
|
|
|
|
|
|
|
19 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
20 |
|
21 |
# تحديد مسارات الملفات المحتملة
|
22 |
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
|
23 |
safetensors_path = os.path.join(model_dir, "model.safetensors")
|
24 |
|
|
|
|
|
|
|
25 |
# إنشاء ViT Base Patch-16 بعدد فئات 5
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# محاولة تحميل النموذج من pytorch_model.bin أولاً
|
29 |
model_loaded = False
|
@@ -31,11 +45,22 @@ class EndpointHandler:
|
|
31 |
try:
|
32 |
print(f"محاولة تحميل النموذج من: {pytorch_path}")
|
33 |
state_dict = torch.load(pytorch_path, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
self.model.load_state_dict(state_dict)
|
35 |
print("تم تحميل النموذج بنجاح من pytorch_model.bin")
|
36 |
model_loaded = True
|
37 |
except Exception as e:
|
38 |
print(f"خطأ في تحميل pytorch_model.bin: {e}")
|
|
|
|
|
39 |
|
40 |
# إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
|
41 |
if not model_loaded and os.path.exists(safetensors_path):
|
@@ -54,6 +79,7 @@ class EndpointHandler:
|
|
54 |
print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
|
55 |
|
56 |
self.model.eval().to(self.device)
|
|
|
57 |
|
58 |
# محوّلات التحضير
|
59 |
self.preprocess = transforms.Compose([
|
@@ -68,15 +94,30 @@ class EndpointHandler:
|
|
68 |
"real",
|
69 |
"other_ai",
|
70 |
]
|
|
|
71 |
|
72 |
# --------------------------------------------------
|
73 |
# 2) دوال مساعدة
|
74 |
# --------------------------------------------------
|
75 |
def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
def _decode_b64(self, b: bytes) -> Image.Image:
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
# --------------------------------------------------
|
82 |
# 3) ��لدالة الرئيسة
|
@@ -90,33 +131,58 @@ class EndpointHandler:
|
|
90 |
يعيد:
|
91 |
• مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...]
|
92 |
"""
|
|
|
|
|
93 |
img: Image.Image | None = None
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import base64
|
2 |
import io
|
3 |
import os
|
4 |
+
import sys
|
5 |
from typing import Dict, Any, List
|
6 |
|
7 |
import torch
|
|
|
17 |
# 1) تحميل النموذج والوزن مرة واحدة
|
18 |
# --------------------------------------------------
|
19 |
def __init__(self, model_dir: str) -> None:
|
20 |
+
print(f"بدء تهيئة النموذج من المسار: {model_dir}")
|
21 |
+
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
|
22 |
+
|
23 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
+
print(f"استخدام الجهاز: {self.device}")
|
25 |
|
26 |
# تحديد مسارات الملفات المحتملة
|
27 |
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
|
28 |
safetensors_path = os.path.join(model_dir, "model.safetensors")
|
29 |
|
30 |
+
print(f"مسار ملف PyTorch: {pytorch_path}, موجود: {os.path.exists(pytorch_path)}")
|
31 |
+
print(f"مسار ملف Safetensors: {safetensors_path}, موجود: {os.path.exists(safetensors_path)}")
|
32 |
+
|
33 |
# إنشاء ViT Base Patch-16 بعدد فئات 5
|
34 |
+
try:
|
35 |
+
print("محاولة إنشاء نموذج ViT")
|
36 |
+
self.model = create_model("vit_base_patch16_224", num_classes=5)
|
37 |
+
print("تم إنشاء نموذج ViT بنجاح")
|
38 |
+
except Exception as e:
|
39 |
+
print(f"خطأ في إنشاء النموذج: {e}")
|
40 |
+
raise
|
41 |
|
42 |
# محاولة تحميل النموذج من pytorch_model.bin أولاً
|
43 |
model_loaded = False
|
|
|
45 |
try:
|
46 |
print(f"محاولة تحميل النموذج من: {pytorch_path}")
|
47 |
state_dict = torch.load(pytorch_path, map_location="cpu")
|
48 |
+
print(f"مفاتيح state_dict: {list(state_dict.keys())[:5]}...")
|
49 |
+
|
50 |
+
# طباعة بنية النموذج ومفاتيح state_dict للمقارنة
|
51 |
+
model_keys = set(k for k, _ in self.model.named_parameters())
|
52 |
+
state_dict_keys = set(state_dict.keys())
|
53 |
+
print(f"عدد مفاتيح النموذج: {len(model_keys)}")
|
54 |
+
print(f"عدد مفاتيح state_dict: {len(state_dict_keys)}")
|
55 |
+
print(f"المفاتيح المشتركة: {len(model_keys.intersection(state_dict_keys))}")
|
56 |
+
|
57 |
self.model.load_state_dict(state_dict)
|
58 |
print("تم تحميل النموذج بنجاح من pytorch_model.bin")
|
59 |
model_loaded = True
|
60 |
except Exception as e:
|
61 |
print(f"خطأ في تحميل pytorch_model.bin: {e}")
|
62 |
+
print(f"نوع الخطأ: {type(e).__name__}")
|
63 |
+
print(f"تفاصيل الخطأ: {str(e)}")
|
64 |
|
65 |
# إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
|
66 |
if not model_loaded and os.path.exists(safetensors_path):
|
|
|
79 |
print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
|
80 |
|
81 |
self.model.eval().to(self.device)
|
82 |
+
print("تم تحويل النموذج إلى وضع التقييم")
|
83 |
|
84 |
# محوّلات التحضير
|
85 |
self.preprocess = transforms.Compose([
|
|
|
94 |
"real",
|
95 |
"other_ai",
|
96 |
]
|
97 |
+
print(f"تم تعريف التسميات: {self.labels}")
|
98 |
|
99 |
# --------------------------------------------------
|
100 |
# 2) دوال مساعدة
|
101 |
# --------------------------------------------------
|
102 |
def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
|
103 |
+
try:
|
104 |
+
print(f"تحويل الصورة إلى تنسور. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
|
105 |
+
tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
|
106 |
+
print(f"تم تحويل الصورة بنجاح. شكل التنسور: {tensor.shape}")
|
107 |
+
return tensor
|
108 |
+
except Exception as e:
|
109 |
+
print(f"خطأ في تحويل الصورة إلى تنسور: {e}")
|
110 |
+
raise
|
111 |
|
112 |
def _decode_b64(self, b: bytes) -> Image.Image:
|
113 |
+
try:
|
114 |
+
print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
|
115 |
+
img = Image.open(io.BytesIO(base64.b64decode(b)))
|
116 |
+
print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
|
117 |
+
return img
|
118 |
+
except Exception as e:
|
119 |
+
print(f"خطأ في فك ترميز base64: {e}")
|
120 |
+
raise
|
121 |
|
122 |
# --------------------------------------------------
|
123 |
# 3) ��لدالة الرئيسة
|
|
|
131 |
يعيد:
|
132 |
• مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...]
|
133 |
"""
|
134 |
+
print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
|
135 |
+
|
136 |
img: Image.Image | None = None
|
137 |
|
138 |
+
try:
|
139 |
+
if isinstance(data, Image.Image):
|
140 |
+
print("البيانات هي صورة PIL")
|
141 |
+
img = data
|
142 |
+
elif isinstance(data, dict):
|
143 |
+
print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}")
|
144 |
+
payload = data.get("inputs") or data.get("image")
|
145 |
+
print(f"نوع الحمولة: {type(payload)}")
|
146 |
+
|
147 |
+
if isinstance(payload, (str, bytes)):
|
148 |
+
if isinstance(payload, str):
|
149 |
+
print("تحويل السلسلة النصية إلى بايت")
|
150 |
+
payload = payload.encode()
|
151 |
+
img = self._decode_b64(payload)
|
152 |
+
|
153 |
+
if img is None:
|
154 |
+
print("لم يتم العثور على صورة صالحة في البيانات")
|
155 |
+
return [{"label": "error", "score": 1.0}]
|
156 |
+
|
157 |
+
print("بدء التنبؤ باستخدام النموذج")
|
158 |
+
with torch.no_grad():
|
159 |
+
tensor = self._img_to_tensor(img)
|
160 |
+
logits = self.model(tensor)
|
161 |
+
probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
|
162 |
+
print(f"تم الحصول على الاحتمالات: {probs}")
|
163 |
+
|
164 |
+
# تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
|
165 |
+
results = []
|
166 |
+
for i, label in enumerate(self.labels):
|
167 |
+
score = float(probs[i])
|
168 |
+
print(f"التسمية: {label}, الدرجة: {score}")
|
169 |
+
results.append({
|
170 |
+
"label": label,
|
171 |
+
"score": score
|
172 |
+
})
|
173 |
+
|
174 |
+
# ترتيب النتائج تنازلياً حسب درجة الثقة
|
175 |
+
results.sort(key=lambda x: x["score"], reverse=True)
|
176 |
+
print(f"النتائج النهائية: {results}")
|
177 |
+
|
178 |
+
return results
|
179 |
|
180 |
+
except Exception as e:
|
181 |
+
print(f"حدث خطأ أثناء المعالجة: {e}")
|
182 |
+
print(f"نوع الخطأ: {type(e).__name__}")
|
183 |
+
print(f"تفاصيل الخطأ: {str(e)}")
|
184 |
+
# تتبع الاستثناء الكامل
|
185 |
+
import traceback
|
186 |
+
traceback.print_exc()
|
187 |
+
|
188 |
+
return [{"label": "error", "score": 1.0}]
|