|
import base64 |
|
import io |
|
import os |
|
from typing import Dict, Any, List |
|
|
|
import torch |
|
from PIL import Image |
|
from transformers import ViTImageProcessor, ViTForImageClassification |
|
|
|
class EndpointHandler: |
|
def __init__(self, model_dir: str) -> None: |
|
print(f"بدء تهيئة النموذج من المسار: {model_dir}") |
|
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}") |
|
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"استخدام الجهاز: {self.device}") |
|
|
|
|
|
try: |
|
print("تحميل معالج الصور ViT") |
|
self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") |
|
|
|
print("تحميل نموذج ViT") |
|
self.model = ViTForImageClassification.from_pretrained( |
|
"google/vit-base-patch16-224", |
|
num_labels=5, |
|
id2label={ |
|
0: "stable_diffusion", |
|
1: "midjourney", |
|
2: "dalle", |
|
3: "real", |
|
4: "other_ai" |
|
}, |
|
label2id={ |
|
"stable_diffusion": 0, |
|
"midjourney": 1, |
|
"dalle": 2, |
|
"real": 3, |
|
"other_ai": 4 |
|
} |
|
) |
|
|
|
|
|
pytorch_path = os.path.join(model_dir, "pytorch_model.bin") |
|
if os.path.exists(pytorch_path): |
|
print(f"محاولة تحميل الأوزان المخصصة من: {pytorch_path}") |
|
try: |
|
state_dict = torch.load(pytorch_path, map_location="cpu") |
|
|
|
self.model.load_state_dict(state_dict, strict=False) |
|
print("تم تحميل الأوزان المخصصة بنجاح") |
|
except Exception as e: |
|
print(f"تحذير: فشل تحميل الأوزان المخصصة: {e}") |
|
|
|
self.model.to(self.device) |
|
self.model.eval() |
|
print("تم تهيئة النموذج بنجاح") |
|
|
|
except Exception as e: |
|
print(f"خطأ في تهيئة النموذج: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
raise |
|
|
|
self.labels = [ |
|
"stable_diffusion", |
|
"midjourney", |
|
"dalle", |
|
"real", |
|
"other_ai", |
|
] |
|
|
|
def _decode_b64(self, b: bytes) -> Image.Image: |
|
try: |
|
print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت") |
|
img = Image.open(io.BytesIO(base64.b64decode(b))) |
|
print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}") |
|
return img |
|
except Exception as e: |
|
print(f"خطأ في فك ترميز base64: {e}") |
|
raise |
|
|
|
def __call__(self, data: Any) -> List[Dict[str, Any]]: |
|
print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}") |
|
|
|
img: Image.Image | None = None |
|
|
|
try: |
|
if isinstance(data, Image.Image): |
|
print("البيانات هي صورة PIL") |
|
img = data |
|
elif isinstance(data, dict): |
|
print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}") |
|
payload = data.get("inputs") or data.get("image") |
|
print(f"نوع الحمولة: {type(payload)}") |
|
|
|
if isinstance(payload, (str, bytes)): |
|
if isinstance(payload, str): |
|
print("تحويل السلسلة النصية إلى بايت") |
|
payload = payload.encode() |
|
img = self._decode_b64(payload) |
|
|
|
if img is None: |
|
print("لم يتم العثور على صورة صالحة في البيانات") |
|
return [{"label": "error", "score": 1.0}] |
|
|
|
print("معالجة الصورة باستخدام معالج ViT") |
|
inputs = self.processor(images=img, return_tensors="pt").to(self.device) |
|
|
|
print("بدء التنبؤ باستخدام النموذج") |
|
with torch.no_grad(): |
|
outputs = self.model(**inputs) |
|
logits = outputs.logits |
|
probs = torch.nn.functional.softmax(logits[0], dim=0) |
|
print(f"تم الحصول على الاحتمالات: {probs}") |
|
|
|
|
|
results = [] |
|
for i, label in enumerate(self.labels): |
|
score = float(probs[i]) |
|
print(f"التسمية: {label}, الدرجة: {score}") |
|
results.append({ |
|
"label": label, |
|
"score": score |
|
}) |
|
|
|
|
|
results.sort(key=lambda x: x["score"], reverse=True) |
|
print(f"النتائج النهائية: {results}") |
|
|
|
return results |
|
|
|
except Exception as e: |
|
print(f"حدث خطأ أثناء المعالجة: {e}") |
|
print(f"نوع الخطأ: {type(e).__name__}") |
|
print(f"تفاصيل الخطأ: {str(e)}") |
|
import traceback |
|
traceback.print_exc() |
|
|
|
return [{"label": "error", "score": 1.0}] |
|
|