yaya36095 commited on
Commit
b50aa1b
·
verified ·
1 Parent(s): ed7f1c1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -75
handler.py CHANGED
@@ -2,7 +2,6 @@ import base64
2
  import io
3
  import os
4
  from typing import Dict, Any, List
5
-
6
  import torch
7
  from PIL import Image
8
  from transformers import ViTImageProcessor, ViTForImageClassification
@@ -11,15 +10,14 @@ class EndpointHandler:
11
  def __init__(self, model_dir: str) -> None:
12
  print(f"بدء تهيئة النموذج من المسار: {model_dir}")
13
  print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
14
-
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  print(f"استخدام الجهاز: {self.device}")
17
-
18
- # استخدام مكتبة transformers بدلاً من timm
19
  try:
20
  print("تحميل معالج الصور ViT")
21
  self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
22
-
23
  print("تحميل نموذج ViT")
24
  self.model = ViTForImageClassification.from_pretrained(
25
  "google/vit-base-patch16-224",
@@ -38,104 +36,68 @@ class EndpointHandler:
38
  "real": 3,
39
  "other_ai": 4
40
  },
41
- ignore_mismatched_sizes=True # إضافة هذه المعلمة لتجاهل عدم تطابق الأحجام
42
  )
43
-
44
- # محاولة تحميل الأوزان المخصصة إذا كانت موجودة
45
- pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
46
- if os.path.exists(pytorch_path):
47
- print(f"محاولة تحميل الأوزان المخصصة من: {pytorch_path}")
48
- try:
49
- state_dict = torch.load(pytorch_path, map_location="cpu")
50
- # تحميل الأوزان المتوافقة فقط
51
- self.model.load_state_dict(state_dict, strict=False)
52
- print("تم تحميل الأوزان المخصصة بنجاح")
53
- except Exception as e:
54
- print(f"تحذير: فشل تحميل الأوزان المخصصة: {e}")
55
-
56
- self.model.to(self.device)
57
- self.model.eval()
58
- print("تم تهيئة النموذج بنجاح")
59
-
60
  except Exception as e:
61
- print(f"خطأ في تهيئة النموذج: {e}")
62
- import traceback
63
- traceback.print_exc()
64
  raise
65
-
66
- self.labels = [
67
- "stable_diffusion",
68
- "midjourney",
69
- "dalle",
70
- "real",
71
- "other_ai",
72
- ]
73
 
74
  def _decode_b64(self, b: bytes) -> Image.Image:
75
  try:
76
  print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
77
- img = Image.open(io.BytesIO(base64.b64decode(b)))
78
- print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
79
- return img
80
  except Exception as e:
81
- print(f"خطأ في فك ترميز base64: {e}")
82
  raise
83
 
84
  def __call__(self, data: Any) -> List[Dict[str, Any]]:
85
- print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
86
-
87
- img: Image.Image | None = None
88
 
 
89
  try:
90
  if isinstance(data, Image.Image):
91
- print("البيانات هي صورة PIL")
92
  img = data
93
  elif isinstance(data, dict):
94
- print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}")
95
  payload = data.get("inputs") or data.get("image")
96
- print(f"نوع الحمولة: {type(payload)}")
97
-
98
- if isinstance(payload, (str, bytes)):
99
- if isinstance(payload, str):
100
- print("تحويل السلسلة النصية إلى بايت")
101
- payload = payload.encode()
102
  img = self._decode_b64(payload)
103
-
104
  if img is None:
105
- print("لم يتم العثور على صورة صالحة في البيانات")
106
  return [{"label": "error", "score": 1.0}]
107
-
108
- print("معالجة الصورة باستخدام معالج ViT")
109
  inputs = self.processor(images=img, return_tensors="pt").to(self.device)
110
-
111
- print("بدء التنبؤ باستخدام النموذج")
112
  with torch.no_grad():
113
  outputs = self.model(**inputs)
114
- logits = outputs.logits
115
- probs = torch.nn.functional.softmax(logits[0], dim=0)
116
- print(f"تم الحصول على الاحتمالات: {probs}")
117
-
118
- # تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
119
  results = []
120
- for i, label in enumerate(self.labels):
121
- score = float(probs[i])
122
- print(f"التسمية: {label}, الدرجة: {score}")
123
  results.append({
124
  "label": label,
125
- "score": score
126
  })
127
-
128
- # ترتيب النتائج تنازلياً حسب درجة الثقة
129
  results.sort(key=lambda x: x["score"], reverse=True)
130
- print(f"النتائج النهائية: {results}")
131
-
132
  return results
133
-
134
  except Exception as e:
135
- print(f"حدث خطأ أثناء المعالجة: {e}")
136
- print(f"نوع الخطأ: {type(e).__name__}")
137
- print(f"تفاصيل الخطأ: {str(e)}")
138
- import traceback
139
- traceback.print_exc()
140
-
141
  return [{"label": "error", "score": 1.0}]
 
2
  import io
3
  import os
4
  from typing import Dict, Any, List
 
5
  import torch
6
  from PIL import Image
7
  from transformers import ViTImageProcessor, ViTForImageClassification
 
10
  def __init__(self, model_dir: str) -> None:
11
  print(f"بدء تهيئة النموذج من المسار: {model_dir}")
12
  print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
13
+
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  print(f"استخدام الجهاز: {self.device}")
16
+
 
17
  try:
18
  print("تحميل معالج الصور ViT")
19
  self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
20
+
21
  print("تحميل نموذج ViT")
22
  self.model = ViTForImageClassification.from_pretrained(
23
  "google/vit-base-patch16-224",
 
36
  "real": 3,
37
  "other_ai": 4
38
  },
39
+ ignore_mismatched_sizes=True
40
  )
41
+
42
+ custom_weights = os.path.join(model_dir, "pytorch_model.bin")
43
+ if os.path.exists(custom_weights):
44
+ print(f"تحميل الأوزان من: {custom_weights}")
45
+ state_dict = torch.load(custom_weights, map_location="cpu")
46
+ self.model.load_state_dict(state_dict, strict=False)
47
+ print("تم تحميل الأوزان بنجاح")
48
+
49
+ self.model.to(self.device).eval()
50
+ self.labels = self.model.config.id2label
51
+
 
 
 
 
 
 
52
  except Exception as e:
53
+ print(f"خطأ أثناء تهيئة النموذج: {e}")
 
 
54
  raise
 
 
 
 
 
 
 
 
55
 
56
  def _decode_b64(self, b: bytes) -> Image.Image:
57
  try:
58
  print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
59
+ return Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB")
 
 
60
  except Exception as e:
61
+ print(f"خطأ في فك الترميز: {e}")
62
  raise
63
 
64
  def __call__(self, data: Any) -> List[Dict[str, Any]]:
65
+ print(f"استدعاء __call__ مع نوع البيانات: {type(data)}")
 
 
66
 
67
+ img = None
68
  try:
69
  if isinstance(data, Image.Image):
 
70
  img = data
71
  elif isinstance(data, dict):
 
72
  payload = data.get("inputs") or data.get("image")
73
+ if isinstance(payload, str):
74
+ payload = payload.encode()
75
+ if isinstance(payload, bytes):
 
 
 
76
  img = self._decode_b64(payload)
77
+
78
  if img is None:
79
+ print("لم يتم العثور على صورة صالحة")
80
  return [{"label": "error", "score": 1.0}]
81
+
82
+ print("تحويل الصورة إلى مدخلات الموديل")
83
  inputs = self.processor(images=img, return_tensors="pt").to(self.device)
84
+
 
85
  with torch.no_grad():
86
  outputs = self.model(**inputs)
87
+ probs = torch.nn.functional.softmax(outputs.logits[0], dim=0)
88
+
 
 
 
89
  results = []
90
+ for i, prob in enumerate(probs):
91
+ label = self.labels[str(i)]
 
92
  results.append({
93
  "label": label,
94
+ "score": round(prob.item(), 4)
95
  })
96
+
 
97
  results.sort(key=lambda x: x["score"], reverse=True)
98
+ print(f"نتائج التصنيف: {results}")
 
99
  return results
100
+
101
  except Exception as e:
102
+ print(f"حدث استثناء: {e}")
 
 
 
 
 
103
  return [{"label": "error", "score": 1.0}]