Update handler.py
Browse files- handler.py +50 -98
handler.py
CHANGED
@@ -1,21 +1,13 @@
|
|
1 |
import base64
|
2 |
import io
|
3 |
import os
|
4 |
-
import sys
|
5 |
from typing import Dict, Any, List
|
6 |
|
7 |
import torch
|
8 |
from PIL import Image
|
9 |
-
from
|
10 |
-
from torchvision import transforms
|
11 |
-
from safetensors.torch import load_file
|
12 |
|
13 |
class EndpointHandler:
|
14 |
-
"""Custom ViT image-classifier for Hugging Face Inference Endpoints."""
|
15 |
-
|
16 |
-
# --------------------------------------------------
|
17 |
-
# 1) تحميل النموذج والوزن مرة واحدة
|
18 |
-
# --------------------------------------------------
|
19 |
def __init__(self, model_dir: str) -> None:
|
20 |
print(f"بدء تهيئة النموذج من المسار: {model_dir}")
|
21 |
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
|
@@ -23,70 +15,53 @@ class EndpointHandler:
|
|
23 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
print(f"استخدام الجهاز: {self.device}")
|
25 |
|
26 |
-
#
|
27 |
-
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
|
28 |
-
safetensors_path = os.path.join(model_dir, "model.safetensors")
|
29 |
-
|
30 |
-
print(f"مسار ملف PyTorch: {pytorch_path}, موجود: {os.path.exists(pytorch_path)}")
|
31 |
-
print(f"مسار ملف Safetensors: {safetensors_path}, موجود: {os.path.exists(safetensors_path)}")
|
32 |
-
|
33 |
-
# إنشاء ViT Base Patch-16 بعدد فئات 5
|
34 |
try:
|
35 |
-
print("
|
36 |
-
self.
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
except Exception as e:
|
39 |
-
print(f"خطأ في
|
|
|
|
|
40 |
raise
|
41 |
|
42 |
-
# محاولة تحميل النموذج من pytorch_model.bin أولاً
|
43 |
-
model_loaded = False
|
44 |
-
if os.path.exists(pytorch_path):
|
45 |
-
try:
|
46 |
-
print(f"محاولة تحميل النموذج من: {pytorch_path}")
|
47 |
-
state_dict = torch.load(pytorch_path, map_location="cpu")
|
48 |
-
print(f"مفاتيح state_dict: {list(state_dict.keys())[:5]}...")
|
49 |
-
|
50 |
-
# طباعة بنية النموذج ومفاتيح state_dict للمقارنة
|
51 |
-
model_keys = set(k for k, _ in self.model.named_parameters())
|
52 |
-
state_dict_keys = set(state_dict.keys())
|
53 |
-
print(f"عدد مفاتيح النموذج: {len(model_keys)}")
|
54 |
-
print(f"عدد مفاتيح state_dict: {len(state_dict_keys)}")
|
55 |
-
print(f"المفاتيح المشتركة: {len(model_keys.intersection(state_dict_keys))}")
|
56 |
-
|
57 |
-
self.model.load_state_dict(state_dict)
|
58 |
-
print("تم تحميل النموذج بنجاح من pytorch_model.bin")
|
59 |
-
model_loaded = True
|
60 |
-
except Exception as e:
|
61 |
-
print(f"خطأ في تحميل pytorch_model.bin: {e}")
|
62 |
-
print(f"نوع الخطأ: {type(e).__name__}")
|
63 |
-
print(f"تفاصيل الخطأ: {str(e)}")
|
64 |
-
|
65 |
-
# إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
|
66 |
-
if not model_loaded and os.path.exists(safetensors_path):
|
67 |
-
try:
|
68 |
-
print(f"محاولة تحميل النموذج من: {safetensors_path}")
|
69 |
-
# تحميل النموذج بدون محاولة مطابقة الهيكل مباشرة
|
70 |
-
# سنقوم بتهيئة النموذج من الصفر بدلاً من ذلك
|
71 |
-
print("تهيئة نموذج ViT من الصفر")
|
72 |
-
# لا نحاول تحميل safetensors لأنه يحتوي على هيكل مختلف
|
73 |
-
print("تم تهيئة نموذج ViT بدون أوزان مسبقة")
|
74 |
-
model_loaded = True
|
75 |
-
except Exception as e:
|
76 |
-
print(f"خطأ في تحميل model.safetensors: {e}")
|
77 |
-
|
78 |
-
if not model_loaded:
|
79 |
-
print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
|
80 |
-
|
81 |
-
self.model.eval().to(self.device)
|
82 |
-
print("تم تحويل النموذج إلى وضع التقييم")
|
83 |
-
|
84 |
-
# محوّلات التحضير
|
85 |
-
self.preprocess = transforms.Compose([
|
86 |
-
transforms.Resize((224, 224), interpolation=Image.BICUBIC),
|
87 |
-
transforms.ToTensor(),
|
88 |
-
])
|
89 |
-
|
90 |
self.labels = [
|
91 |
"stable_diffusion",
|
92 |
"midjourney",
|
@@ -94,20 +69,6 @@ class EndpointHandler:
|
|
94 |
"real",
|
95 |
"other_ai",
|
96 |
]
|
97 |
-
print(f"تم تعريف التسميات: {self.labels}")
|
98 |
-
|
99 |
-
# --------------------------------------------------
|
100 |
-
# 2) دوال مساعدة
|
101 |
-
# --------------------------------------------------
|
102 |
-
def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
|
103 |
-
try:
|
104 |
-
print(f"تحويل الصورة إلى تنسور. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
|
105 |
-
tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
|
106 |
-
print(f"تم تحويل الصورة بنجاح. شكل التنسور: {tensor.shape}")
|
107 |
-
return tensor
|
108 |
-
except Exception as e:
|
109 |
-
print(f"خطأ في تحويل الصورة إلى تنسور: {e}")
|
110 |
-
raise
|
111 |
|
112 |
def _decode_b64(self, b: bytes) -> Image.Image:
|
113 |
try:
|
@@ -119,18 +80,7 @@ class EndpointHandler:
|
|
119 |
print(f"خطأ في فك ترميز base64: {e}")
|
120 |
raise
|
121 |
|
122 |
-
# --------------------------------------------------
|
123 |
-
# 3) الدالة الرئيسة
|
124 |
-
# --------------------------------------------------
|
125 |
def __call__(self, data: Any) -> List[Dict[str, Any]]:
|
126 |
-
"""
|
127 |
-
يدعم:
|
128 |
-
• Widget (PIL.Image)
|
129 |
-
• REST (base64 فى data["inputs"] أو data["image"])
|
130 |
-
|
131 |
-
يعيد:
|
132 |
-
• مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...]
|
133 |
-
"""
|
134 |
print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
|
135 |
|
136 |
img: Image.Image | None = None
|
@@ -154,11 +104,14 @@ class EndpointHandler:
|
|
154 |
print("لم يتم العثور على صورة صالحة في البيانات")
|
155 |
return [{"label": "error", "score": 1.0}]
|
156 |
|
|
|
|
|
|
|
157 |
print("بدء التنبؤ باستخدام النموذج")
|
158 |
with torch.no_grad():
|
159 |
-
|
160 |
-
logits =
|
161 |
-
probs = torch.nn.functional.softmax(logits
|
162 |
print(f"تم الحصول على الاحتمالات: {probs}")
|
163 |
|
164 |
# تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
|
@@ -181,7 +134,6 @@ class EndpointHandler:
|
|
181 |
print(f"حدث خطأ أثناء المعالجة: {e}")
|
182 |
print(f"نوع الخطأ: {type(e).__name__}")
|
183 |
print(f"تفاصيل الخطأ: {str(e)}")
|
184 |
-
# تتبع الاستثناء الكامل
|
185 |
import traceback
|
186 |
traceback.print_exc()
|
187 |
|
|
|
1 |
import base64
|
2 |
import io
|
3 |
import os
|
|
|
4 |
from typing import Dict, Any, List
|
5 |
|
6 |
import torch
|
7 |
from PIL import Image
|
8 |
+
from transformers import ViTImageProcessor, ViTForImageClassification
|
|
|
|
|
9 |
|
10 |
class EndpointHandler:
|
|
|
|
|
|
|
|
|
|
|
11 |
def __init__(self, model_dir: str) -> None:
|
12 |
print(f"بدء تهيئة النموذج من المسار: {model_dir}")
|
13 |
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
|
|
|
15 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
print(f"استخدام الجهاز: {self.device}")
|
17 |
|
18 |
+
# استخدام مكتبة transformers بدلاً من timm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
try:
|
20 |
+
print("تحميل معالج الصور ViT")
|
21 |
+
self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
22 |
+
|
23 |
+
print("تحميل نموذج ViT")
|
24 |
+
self.model = ViTForImageClassification.from_pretrained(
|
25 |
+
"google/vit-base-patch16-224",
|
26 |
+
num_labels=5,
|
27 |
+
id2label={
|
28 |
+
0: "stable_diffusion",
|
29 |
+
1: "midjourney",
|
30 |
+
2: "dalle",
|
31 |
+
3: "real",
|
32 |
+
4: "other_ai"
|
33 |
+
},
|
34 |
+
label2id={
|
35 |
+
"stable_diffusion": 0,
|
36 |
+
"midjourney": 1,
|
37 |
+
"dalle": 2,
|
38 |
+
"real": 3,
|
39 |
+
"other_ai": 4
|
40 |
+
}
|
41 |
+
)
|
42 |
+
|
43 |
+
# محاولة تحميل الأوزان المخصصة إذا كانت موجودة
|
44 |
+
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
|
45 |
+
if os.path.exists(pytorch_path):
|
46 |
+
print(f"محاولة تحميل الأوزان المخصصة من: {pytorch_path}")
|
47 |
+
try:
|
48 |
+
state_dict = torch.load(pytorch_path, map_location="cpu")
|
49 |
+
# تحميل الأوزان المتوافقة فقط
|
50 |
+
self.model.load_state_dict(state_dict, strict=False)
|
51 |
+
print("تم تحميل الأوزان المخصصة بنجاح")
|
52 |
+
except Exception as e:
|
53 |
+
print(f"تحذير: فشل تحميل الأوزان المخصصة: {e}")
|
54 |
+
|
55 |
+
self.model.to(self.device)
|
56 |
+
self.model.eval()
|
57 |
+
print("تم تهيئة النموذج بنجاح")
|
58 |
+
|
59 |
except Exception as e:
|
60 |
+
print(f"خطأ في تهيئة النموذج: {e}")
|
61 |
+
import traceback
|
62 |
+
traceback.print_exc()
|
63 |
raise
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
self.labels = [
|
66 |
"stable_diffusion",
|
67 |
"midjourney",
|
|
|
69 |
"real",
|
70 |
"other_ai",
|
71 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
def _decode_b64(self, b: bytes) -> Image.Image:
|
74 |
try:
|
|
|
80 |
print(f"خطأ في فك ترميز base64: {e}")
|
81 |
raise
|
82 |
|
|
|
|
|
|
|
83 |
def __call__(self, data: Any) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
|
85 |
|
86 |
img: Image.Image | None = None
|
|
|
104 |
print("لم يتم العثور على صورة صالحة في البيانات")
|
105 |
return [{"label": "error", "score": 1.0}]
|
106 |
|
107 |
+
print("معالجة الصورة باستخدام معالج ViT")
|
108 |
+
inputs = self.processor(images=img, return_tensors="pt").to(self.device)
|
109 |
+
|
110 |
print("بدء التنبؤ باستخدام النموذج")
|
111 |
with torch.no_grad():
|
112 |
+
outputs = self.model(**inputs)
|
113 |
+
logits = outputs.logits
|
114 |
+
probs = torch.nn.functional.softmax(logits[0], dim=0)
|
115 |
print(f"تم الحصول على الاحتمالات: {probs}")
|
116 |
|
117 |
# تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
|
|
|
134 |
print(f"حدث خطأ أثناء المعالجة: {e}")
|
135 |
print(f"نوع الخطأ: {type(e).__name__}")
|
136 |
print(f"تفاصيل الخطأ: {str(e)}")
|
|
|
137 |
import traceback
|
138 |
traceback.print_exc()
|
139 |
|