ai-source-detector / handler.py
yaya36095's picture
Update handler.py
bc35f37 verified
raw
history blame
8.89 kB
import base64
import io
import os
import sys
from typing import Dict, Any, List
import torch
from PIL import Image
from timm import create_model
from torchvision import transforms
from safetensors.torch import load_file
class EndpointHandler:
"""Custom ViT image-classifier for Hugging Face Inference Endpoints."""
# --------------------------------------------------
# 1) تحميل النموذج والوزن مرة واحدة
# --------------------------------------------------
def __init__(self, model_dir: str) -> None:
print(f"بدء تهيئة النموذج من المسار: {model_dir}")
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"استخدام الجهاز: {self.device}")
# تحديد مسارات الملفات المحتملة
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
safetensors_path = os.path.join(model_dir, "model.safetensors")
print(f"مسار ملف PyTorch: {pytorch_path}, موجود: {os.path.exists(pytorch_path)}")
print(f"مسار ملف Safetensors: {safetensors_path}, موجود: {os.path.exists(safetensors_path)}")
# إنشاء ViT Base Patch-16 بعدد فئات 5
try:
print("محاولة إنشاء نموذج ViT")
self.model = create_model("vit_base_patch16_224", num_classes=5)
print("تم إنشاء نموذج ViT بنجاح")
except Exception as e:
print(f"خطأ في إنشاء النموذج: {e}")
raise
# محاولة تحميل النموذج من pytorch_model.bin أولاً
model_loaded = False
if os.path.exists(pytorch_path):
try:
print(f"محاولة تحميل النموذج من: {pytorch_path}")
state_dict = torch.load(pytorch_path, map_location="cpu")
print(f"مفاتيح state_dict: {list(state_dict.keys())[:5]}...")
# طباعة بنية النموذج ومفاتيح state_dict للمقارنة
model_keys = set(k for k, _ in self.model.named_parameters())
state_dict_keys = set(state_dict.keys())
print(f"عدد مفاتيح النموذج: {len(model_keys)}")
print(f"عدد مفاتيح state_dict: {len(state_dict_keys)}")
print(f"المفاتيح المشتركة: {len(model_keys.intersection(state_dict_keys))}")
self.model.load_state_dict(state_dict)
print("تم تحميل النموذج بنجاح من pytorch_model.bin")
model_loaded = True
except Exception as e:
print(f"خطأ في تحميل pytorch_model.bin: {e}")
print(f"نوع الخطأ: {type(e).__name__}")
print(f"تفاصيل الخطأ: {str(e)}")
# إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
if not model_loaded and os.path.exists(safetensors_path):
try:
print(f"محاولة تحميل النموذج من: {safetensors_path}")
# تحميل النموذج بدون محاولة مطابقة الهيكل مباشرة
# سنقوم بتهيئة النموذج من الصفر بدلاً من ذلك
print("تهيئة نموذج ViT من الصفر")
# لا نحاول تحميل safetensors لأنه يحتوي على هيكل مختلف
print("تم تهيئة نموذج ViT بدون أوزان مسبقة")
model_loaded = True
except Exception as e:
print(f"خطأ في تحميل model.safetensors: {e}")
if not model_loaded:
print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
self.model.eval().to(self.device)
print("تم تحويل النموذج إلى وضع التقييم")
# محوّلات التحضير
self.preprocess = transforms.Compose([
transforms.Resize((224, 224), interpolation=Image.BICUBIC),
transforms.ToTensor(),
])
self.labels = [
"stable_diffusion",
"midjourney",
"dalle",
"real",
"other_ai",
]
print(f"تم تعريف التسميات: {self.labels}")
# --------------------------------------------------
# 2) دوال مساعدة
# --------------------------------------------------
def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
try:
print(f"تحويل الصورة إلى تنسور. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
print(f"تم تحويل الصورة بنجاح. شكل التنسور: {tensor.shape}")
return tensor
except Exception as e:
print(f"خطأ في تحويل الصورة إلى تنسور: {e}")
raise
def _decode_b64(self, b: bytes) -> Image.Image:
try:
print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
img = Image.open(io.BytesIO(base64.b64decode(b)))
print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
return img
except Exception as e:
print(f"خطأ في فك ترميز base64: {e}")
raise
# --------------------------------------------------
# 3) الدالة الرئيسة
# --------------------------------------------------
def __call__(self, data: Any) -> List[Dict[str, Any]]:
"""
يدعم:
• Widget (PIL.Image)
• REST (base64 فى data["inputs"] أو data["image"])
يعيد:
• مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...]
"""
print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
img: Image.Image | None = None
try:
if isinstance(data, Image.Image):
print("البيانات هي صورة PIL")
img = data
elif isinstance(data, dict):
print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}")
payload = data.get("inputs") or data.get("image")
print(f"نوع الحمولة: {type(payload)}")
if isinstance(payload, (str, bytes)):
if isinstance(payload, str):
print("تحويل السلسلة النصية إلى بايت")
payload = payload.encode()
img = self._decode_b64(payload)
if img is None:
print("لم يتم العثور على صورة صالحة في البيانات")
return [{"label": "error", "score": 1.0}]
print("بدء التنبؤ باستخدام النموذج")
with torch.no_grad():
tensor = self._img_to_tensor(img)
logits = self.model(tensor)
probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
print(f"تم الحصول على الاحتمالات: {probs}")
# تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
results = []
for i, label in enumerate(self.labels):
score = float(probs[i])
print(f"التسمية: {label}, الدرجة: {score}")
results.append({
"label": label,
"score": score
})
# ترتيب النتائج تنازلياً حسب درجة الثقة
results.sort(key=lambda x: x["score"], reverse=True)
print(f"النتائج النهائية: {results}")
return results
except Exception as e:
print(f"حدث خطأ أثناء المعالجة: {e}")
print(f"نوع الخطأ: {type(e).__name__}")
print(f"تفاصيل الخطأ: {str(e)}")
# تتبع الاستثناء الكامل
import traceback
traceback.print_exc()
return [{"label": "error", "score": 1.0}]