import base64 import io import os from typing import Dict, Any, List import torch from PIL import Image from transformers import ViTImageProcessor, ViTForImageClassification class EndpointHandler: def __init__(self, model_dir: str) -> None: print(f"بدء تهيئة النموذج من المسار: {model_dir}") print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"استخدام الجهاز: {self.device}") # استخدام مكتبة transformers بدلاً من timm try: print("تحميل معالج الصور ViT") self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") print("تحميل نموذج ViT") self.model = ViTForImageClassification.from_pretrained( "google/vit-base-patch16-224", num_labels=5, id2label={ 0: "stable_diffusion", 1: "midjourney", 2: "dalle", 3: "real", 4: "other_ai" }, label2id={ "stable_diffusion": 0, "midjourney": 1, "dalle": 2, "real": 3, "other_ai": 4 } ) # محاولة تحميل الأوزان المخصصة إذا كانت موجودة pytorch_path = os.path.join(model_dir, "pytorch_model.bin") if os.path.exists(pytorch_path): print(f"محاولة تحميل الأوزان المخصصة من: {pytorch_path}") try: state_dict = torch.load(pytorch_path, map_location="cpu") # تحميل الأوزان المتوافقة فقط self.model.load_state_dict(state_dict, strict=False) print("تم تحميل الأوزان المخصصة بنجاح") except Exception as e: print(f"تحذير: فشل تحميل الأوزان المخصصة: {e}") self.model.to(self.device) self.model.eval() print("تم تهيئة النموذج بنجاح") except Exception as e: print(f"خطأ في تهيئة النموذج: {e}") import traceback traceback.print_exc() raise self.labels = [ "stable_diffusion", "midjourney", "dalle", "real", "other_ai", ] def _decode_b64(self, b: bytes) -> Image.Image: try: print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت") img = Image.open(io.BytesIO(base64.b64decode(b))) print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}") return img except Exception as e: print(f"خطأ في فك ترميز base64: {e}") raise def __call__(self, data: Any) -> List[Dict[str, Any]]: print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}") img: Image.Image | None = None try: if isinstance(data, Image.Image): print("البيانات هي صورة PIL") img = data elif isinstance(data, dict): print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}") payload = data.get("inputs") or data.get("image") print(f"نوع الحمولة: {type(payload)}") if isinstance(payload, (str, bytes)): if isinstance(payload, str): print("تحويل السلسلة النصية إلى بايت") payload = payload.encode() img = self._decode_b64(payload) if img is None: print("لم يتم العثور على صورة صالحة في البيانات") return [{"label": "error", "score": 1.0}] print("معالجة الصورة باستخدام معالج ViT") inputs = self.processor(images=img, return_tensors="pt").to(self.device) print("بدء التنبؤ باستخدام النموذج") with torch.no_grad(): outputs = self.model(**inputs) logits = outputs.logits probs = torch.nn.functional.softmax(logits[0], dim=0) print(f"تم الحصول على الاحتمالات: {probs}") # تحويل النتائج إلى التنسيق المطلوب: Array results = [] for i, label in enumerate(self.labels): score = float(probs[i]) print(f"التسمية: {label}, الدرجة: {score}") results.append({ "label": label, "score": score }) # ترتيب النتائج تنازلياً حسب درجة الثقة results.sort(key=lambda x: x["score"], reverse=True) print(f"النتائج النهائية: {results}") return results except Exception as e: print(f"حدث خطأ أثناء المعالجة: {e}") print(f"نوع الخطأ: {type(e).__name__}") print(f"تفاصيل الخطأ: {str(e)}") import traceback traceback.print_exc() return [{"label": "error", "score": 1.0}]