File size: 8,893 Bytes
dd32056 bc35f37 54240fb dd32056 00fa5d2 bf94d8e dd32056 9114905 00fa5d2 8ea56b1 ba3f2d0 8ea56b1 dd32056 ba3f2d0 dd32056 bc35f37 00fa5d2 bc35f37 8b1e242 9114905 8b1e242 9114905 bc35f37 ba3f2d0 bc35f37 8b1e242 bc35f37 8b1e242 bc35f37 8b1e242 00fa5d2 bc35f37 8b1e242 dd32056 8b1e242 dd32056 bc35f37 00fa5d2 dd32056 ba3f2d0 bc35f37 00fa5d2 ba3f2d0 bc35f37 ba3f2d0 dd32056 ba3f2d0 dd32056 54240fb dd32056 ba3f2d0 9114905 54240fb dd32056 bc35f37 dd32056 ba3f2d0 bc35f37 54240fb bc35f37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import base64
import io
import os
import sys
from typing import Dict, Any, List
import torch
from PIL import Image
from timm import create_model
from torchvision import transforms
from safetensors.torch import load_file
class EndpointHandler:
"""Custom ViT image-classifier for Hugging Face Inference Endpoints."""
# --------------------------------------------------
# 1) تحميل النموذج والوزن مرة واحدة
# --------------------------------------------------
def __init__(self, model_dir: str) -> None:
print(f"بدء تهيئة النموذج من المسار: {model_dir}")
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"استخدام الجهاز: {self.device}")
# تحديد مسارات الملفات المحتملة
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
safetensors_path = os.path.join(model_dir, "model.safetensors")
print(f"مسار ملف PyTorch: {pytorch_path}, موجود: {os.path.exists(pytorch_path)}")
print(f"مسار ملف Safetensors: {safetensors_path}, موجود: {os.path.exists(safetensors_path)}")
# إنشاء ViT Base Patch-16 بعدد فئات 5
try:
print("محاولة إنشاء نموذج ViT")
self.model = create_model("vit_base_patch16_224", num_classes=5)
print("تم إنشاء نموذج ViT بنجاح")
except Exception as e:
print(f"خطأ في إنشاء النموذج: {e}")
raise
# محاولة تحميل النموذج من pytorch_model.bin أولاً
model_loaded = False
if os.path.exists(pytorch_path):
try:
print(f"محاولة تحميل النموذج من: {pytorch_path}")
state_dict = torch.load(pytorch_path, map_location="cpu")
print(f"مفاتيح state_dict: {list(state_dict.keys())[:5]}...")
# طباعة بنية النموذج ومفاتيح state_dict للمقارنة
model_keys = set(k for k, _ in self.model.named_parameters())
state_dict_keys = set(state_dict.keys())
print(f"عدد مفاتيح النموذج: {len(model_keys)}")
print(f"عدد مفاتيح state_dict: {len(state_dict_keys)}")
print(f"المفاتيح المشتركة: {len(model_keys.intersection(state_dict_keys))}")
self.model.load_state_dict(state_dict)
print("تم تحميل النموذج بنجاح من pytorch_model.bin")
model_loaded = True
except Exception as e:
print(f"خطأ في تحميل pytorch_model.bin: {e}")
print(f"نوع الخطأ: {type(e).__name__}")
print(f"تفاصيل الخطأ: {str(e)}")
# إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
if not model_loaded and os.path.exists(safetensors_path):
try:
print(f"محاولة تحميل النموذج من: {safetensors_path}")
# تحميل النموذج بدون محاولة مطابقة الهيكل مباشرة
# سنقوم بتهيئة النموذج من الصفر بدلاً من ذلك
print("تهيئة نموذج ViT من الصفر")
# لا نحاول تحميل safetensors لأنه يحتوي على هيكل مختلف
print("تم تهيئة نموذج ViT بدون أوزان مسبقة")
model_loaded = True
except Exception as e:
print(f"خطأ في تحميل model.safetensors: {e}")
if not model_loaded:
print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
self.model.eval().to(self.device)
print("تم تحويل النموذج إلى وضع التقييم")
# محوّلات التحضير
self.preprocess = transforms.Compose([
transforms.Resize((224, 224), interpolation=Image.BICUBIC),
transforms.ToTensor(),
])
self.labels = [
"stable_diffusion",
"midjourney",
"dalle",
"real",
"other_ai",
]
print(f"تم تعريف التسميات: {self.labels}")
# --------------------------------------------------
# 2) دوال مساعدة
# --------------------------------------------------
def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
try:
print(f"تحويل الصورة إلى تنسور. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
print(f"تم تحويل الصورة بنجاح. شكل التنسور: {tensor.shape}")
return tensor
except Exception as e:
print(f"خطأ في تحويل الصورة إلى تنسور: {e}")
raise
def _decode_b64(self, b: bytes) -> Image.Image:
try:
print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
img = Image.open(io.BytesIO(base64.b64decode(b)))
print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
return img
except Exception as e:
print(f"خطأ في فك ترميز base64: {e}")
raise
# --------------------------------------------------
# 3) الدالة الرئيسة
# --------------------------------------------------
def __call__(self, data: Any) -> List[Dict[str, Any]]:
"""
يدعم:
• Widget (PIL.Image)
• REST (base64 فى data["inputs"] أو data["image"])
يعيد:
• مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...]
"""
print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
img: Image.Image | None = None
try:
if isinstance(data, Image.Image):
print("البيانات هي صورة PIL")
img = data
elif isinstance(data, dict):
print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}")
payload = data.get("inputs") or data.get("image")
print(f"نوع الحمولة: {type(payload)}")
if isinstance(payload, (str, bytes)):
if isinstance(payload, str):
print("تحويل السلسلة النصية إلى بايت")
payload = payload.encode()
img = self._decode_b64(payload)
if img is None:
print("لم يتم العثور على صورة صالحة في البيانات")
return [{"label": "error", "score": 1.0}]
print("بدء التنبؤ باستخدام النموذج")
with torch.no_grad():
tensor = self._img_to_tensor(img)
logits = self.model(tensor)
probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
print(f"تم الحصول على الاحتمالات: {probs}")
# تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
results = []
for i, label in enumerate(self.labels):
score = float(probs[i])
print(f"التسمية: {label}, الدرجة: {score}")
results.append({
"label": label,
"score": score
})
# ترتيب النتائج تنازلياً حسب درجة الثقة
results.sort(key=lambda x: x["score"], reverse=True)
print(f"النتائج النهائية: {results}")
return results
except Exception as e:
print(f"حدث خطأ أثناء المعالجة: {e}")
print(f"نوع الخطأ: {type(e).__name__}")
print(f"تفاصيل الخطأ: {str(e)}")
# تتبع الاستثناء الكامل
import traceback
traceback.print_exc()
return [{"label": "error", "score": 1.0}]
|