ai-source-detector / handler.py
yaya36095's picture
Update handler.py
b50aa1b verified
raw
history blame
3.98 kB
import base64
import io
import os
from typing import Dict, Any, List
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
class EndpointHandler:
def __init__(self, model_dir: str) -> None:
print(f"بدء تهيئة النموذج من المسار: {model_dir}")
print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"استخدام الجهاز: {self.device}")
try:
print("تحميل معالج الصور ViT")
self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
print("تحميل نموذج ViT")
self.model = ViTForImageClassification.from_pretrained(
"google/vit-base-patch16-224",
num_labels=5,
id2label={
0: "stable_diffusion",
1: "midjourney",
2: "dalle",
3: "real",
4: "other_ai"
},
label2id={
"stable_diffusion": 0,
"midjourney": 1,
"dalle": 2,
"real": 3,
"other_ai": 4
},
ignore_mismatched_sizes=True
)
custom_weights = os.path.join(model_dir, "pytorch_model.bin")
if os.path.exists(custom_weights):
print(f"تحميل الأوزان من: {custom_weights}")
state_dict = torch.load(custom_weights, map_location="cpu")
self.model.load_state_dict(state_dict, strict=False)
print("تم تحميل الأوزان بنجاح")
self.model.to(self.device).eval()
self.labels = self.model.config.id2label
except Exception as e:
print(f"خطأ أثناء تهيئة النموذج: {e}")
raise
def _decode_b64(self, b: bytes) -> Image.Image:
try:
print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
return Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB")
except Exception as e:
print(f"خطأ في فك الترميز: {e}")
raise
def __call__(self, data: Any) -> List[Dict[str, Any]]:
print(f"استدعاء __call__ مع نوع البيانات: {type(data)}")
img = None
try:
if isinstance(data, Image.Image):
img = data
elif isinstance(data, dict):
payload = data.get("inputs") or data.get("image")
if isinstance(payload, str):
payload = payload.encode()
if isinstance(payload, bytes):
img = self._decode_b64(payload)
if img is None:
print("لم يتم العثور على صورة صالحة")
return [{"label": "error", "score": 1.0}]
print("تحويل الصورة إلى مدخلات الموديل")
inputs = self.processor(images=img, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits[0], dim=0)
results = []
for i, prob in enumerate(probs):
label = self.labels[str(i)]
results.append({
"label": label,
"score": round(prob.item(), 4)
})
results.sort(key=lambda x: x["score"], reverse=True)
print(f"نتائج التصنيف: {results}")
return results
except Exception as e:
print(f"حدث استثناء: {e}")
return [{"label": "error", "score": 1.0}]