|
import base64 |
|
import io |
|
import os |
|
import sys |
|
from typing import Dict, Any, List |
|
|
|
import torch |
|
from PIL import Image |
|
from timm import create_model |
|
from torchvision import transforms |
|
from safetensors.torch import load_file |
|
|
|
class EndpointHandler: |
|
"""Custom ViT image-classifier for Hugging Face Inference Endpoints.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, model_dir: str) -> None: |
|
print(f"بدء تهيئة النموذج من المسار: {model_dir}") |
|
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}") |
|
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"استخدام الجهاز: {self.device}") |
|
|
|
|
|
pytorch_path = os.path.join(model_dir, "pytorch_model.bin") |
|
safetensors_path = os.path.join(model_dir, "model.safetensors") |
|
|
|
print(f"مسار ملف PyTorch: {pytorch_path}, موجود: {os.path.exists(pytorch_path)}") |
|
print(f"مسار ملف Safetensors: {safetensors_path}, موجود: {os.path.exists(safetensors_path)}") |
|
|
|
|
|
try: |
|
print("محاولة إنشاء نموذج ViT") |
|
self.model = create_model("vit_base_patch16_224", num_classes=5) |
|
print("تم إنشاء نموذج ViT بنجاح") |
|
except Exception as e: |
|
print(f"خطأ في إنشاء النموذج: {e}") |
|
raise |
|
|
|
|
|
model_loaded = False |
|
if os.path.exists(pytorch_path): |
|
try: |
|
print(f"محاولة تحميل النموذج من: {pytorch_path}") |
|
state_dict = torch.load(pytorch_path, map_location="cpu") |
|
print(f"مفاتيح state_dict: {list(state_dict.keys())[:5]}...") |
|
|
|
|
|
model_keys = set(k for k, _ in self.model.named_parameters()) |
|
state_dict_keys = set(state_dict.keys()) |
|
print(f"عدد مفاتيح النموذج: {len(model_keys)}") |
|
print(f"عدد مفاتيح state_dict: {len(state_dict_keys)}") |
|
print(f"المفاتيح المشتركة: {len(model_keys.intersection(state_dict_keys))}") |
|
|
|
self.model.load_state_dict(state_dict) |
|
print("تم تحميل النموذج بنجاح من pytorch_model.bin") |
|
model_loaded = True |
|
except Exception as e: |
|
print(f"خطأ في تحميل pytorch_model.bin: {e}") |
|
print(f"نوع الخطأ: {type(e).__name__}") |
|
print(f"تفاصيل الخطأ: {str(e)}") |
|
|
|
|
|
if not model_loaded and os.path.exists(safetensors_path): |
|
try: |
|
print(f"محاولة تحميل النموذج من: {safetensors_path}") |
|
|
|
|
|
print("تهيئة نموذج ViT من الصفر") |
|
|
|
print("تم تهيئة نموذج ViT بدون أوزان مسبقة") |
|
model_loaded = True |
|
except Exception as e: |
|
print(f"خطأ في تحميل model.safetensors: {e}") |
|
|
|
if not model_loaded: |
|
print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.") |
|
|
|
self.model.eval().to(self.device) |
|
print("تم تحويل النموذج إلى وضع التقييم") |
|
|
|
|
|
self.preprocess = transforms.Compose([ |
|
transforms.Resize((224, 224), interpolation=Image.BICUBIC), |
|
transforms.ToTensor(), |
|
]) |
|
|
|
self.labels = [ |
|
"stable_diffusion", |
|
"midjourney", |
|
"dalle", |
|
"real", |
|
"other_ai", |
|
] |
|
print(f"تم تعريف التسميات: {self.labels}") |
|
|
|
|
|
|
|
|
|
def _img_to_tensor(self, img: Image.Image) -> torch.Tensor: |
|
try: |
|
print(f"تحويل الصورة إلى تنسور. حجم الصورة: {img.size}, وضع الصورة: {img.mode}") |
|
tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device) |
|
print(f"تم تحويل الصورة بنجاح. شكل التنسور: {tensor.shape}") |
|
return tensor |
|
except Exception as e: |
|
print(f"خطأ في تحويل الصورة إلى تنسور: {e}") |
|
raise |
|
|
|
def _decode_b64(self, b: bytes) -> Image.Image: |
|
try: |
|
print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت") |
|
img = Image.open(io.BytesIO(base64.b64decode(b))) |
|
print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}") |
|
return img |
|
except Exception as e: |
|
print(f"خطأ في فك ترميز base64: {e}") |
|
raise |
|
|
|
|
|
|
|
|
|
def __call__(self, data: Any) -> List[Dict[str, Any]]: |
|
""" |
|
يدعم: |
|
• Widget (PIL.Image) |
|
• REST (base64 فى data["inputs"] أو data["image"]) |
|
|
|
يعيد: |
|
• مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...] |
|
""" |
|
print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}") |
|
|
|
img: Image.Image | None = None |
|
|
|
try: |
|
if isinstance(data, Image.Image): |
|
print("البيانات هي صورة PIL") |
|
img = data |
|
elif isinstance(data, dict): |
|
print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}") |
|
payload = data.get("inputs") or data.get("image") |
|
print(f"نوع الحمولة: {type(payload)}") |
|
|
|
if isinstance(payload, (str, bytes)): |
|
if isinstance(payload, str): |
|
print("تحويل السلسلة النصية إلى بايت") |
|
payload = payload.encode() |
|
img = self._decode_b64(payload) |
|
|
|
if img is None: |
|
print("لم يتم العثور على صورة صالحة في البيانات") |
|
return [{"label": "error", "score": 1.0}] |
|
|
|
print("بدء التنبؤ باستخدام النموذج") |
|
with torch.no_grad(): |
|
tensor = self._img_to_tensor(img) |
|
logits = self.model(tensor) |
|
probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0) |
|
print(f"تم الحصول على الاحتمالات: {probs}") |
|
|
|
|
|
results = [] |
|
for i, label in enumerate(self.labels): |
|
score = float(probs[i]) |
|
print(f"التسمية: {label}, الدرجة: {score}") |
|
results.append({ |
|
"label": label, |
|
"score": score |
|
}) |
|
|
|
|
|
results.sort(key=lambda x: x["score"], reverse=True) |
|
print(f"النتائج النهائية: {results}") |
|
|
|
return results |
|
|
|
except Exception as e: |
|
print(f"حدث خطأ أثناء المعالجة: {e}") |
|
print(f"نوع الخطأ: {type(e).__name__}") |
|
print(f"تفاصيل الخطأ: {str(e)}") |
|
|
|
import traceback |
|
traceback.print_exc() |
|
|
|
return [{"label": "error", "score": 1.0}] |
|
|