File size: 3,975 Bytes
dd32056 54240fb 00fa5d2 963149c 00fa5d2 8ea56b1 dd32056 bc35f37 b50aa1b 00fa5d2 bc35f37 b50aa1b bc35f37 963149c b50aa1b 963149c ed7f1c1 b50aa1b 963149c b50aa1b bc35f37 b50aa1b bc35f37 00fa5d2 ba3f2d0 bc35f37 b50aa1b bc35f37 b50aa1b bc35f37 ba3f2d0 54240fb b50aa1b ba3f2d0 b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b 963149c b50aa1b bc35f37 963149c b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b bc35f37 b50aa1b 9a050d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import base64
import io
import os
from typing import Dict, Any, List
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
class EndpointHandler:
    """Inference-endpoint handler for a ViT-based AI-image-source classifier.

    Loads a ``google/vit-base-patch16-224`` backbone with a 5-way head
    (stable_diffusion / midjourney / dalle / real / other_ai), optionally
    overlays fine-tuned weights found in ``model_dir``, and classifies
    base64-encoded images passed to ``__call__``.
    """

    def __init__(self, model_dir: str) -> None:
        """Load the image processor and model, moving the model to GPU if available.

        Args:
            model_dir: Directory that may contain a ``pytorch_model.bin``
                state dict with fine-tuned weights to overlay on the backbone.

        Raises:
            Exception: Any error during processor/model loading is logged
                and re-raised so the endpoint fails fast on a bad setup.
        """
        print(f"بدء تهيئة النموذج من المسار: {model_dir}")
        print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"استخدام الجهاز: {self.device}")
        try:
            print("تحميل معالج الصور ViT")
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
            print("تحميل نموذج ViT")
            self.model = ViTForImageClassification.from_pretrained(
                "google/vit-base-patch16-224",
                num_labels=5,
                id2label={
                    0: "stable_diffusion",
                    1: "midjourney",
                    2: "dalle",
                    3: "real",
                    4: "other_ai",
                },
                label2id={
                    "stable_diffusion": 0,
                    "midjourney": 1,
                    "dalle": 2,
                    "real": 3,
                    "other_ai": 4,
                },
                # The hub checkpoint ships a 1000-class head; allow it to be
                # replaced by our freshly initialized 5-class head.
                ignore_mismatched_sizes=True,
            )
            custom_weights = os.path.join(model_dir, "pytorch_model.bin")
            if os.path.exists(custom_weights):
                print(f"تحميل الأوزان من: {custom_weights}")
                # weights_only=True blocks arbitrary-code pickle payloads;
                # a plain state dict loads fine under this restriction.
                state_dict = torch.load(custom_weights, map_location="cpu", weights_only=True)
                # strict=False tolerates missing/extra keys (e.g. a checkpoint
                # saved from a slightly different head).
                self.model.load_state_dict(state_dict, strict=False)
                print("تم تحميل الأوزان بنجاح")
            self.model.to(self.device).eval()
            self.labels = self.model.config.id2label
        except Exception as e:
            print(f"خطأ أثناء تهيئة النموذج: {e}")
            raise

    def _decode_b64(self, b: bytes) -> Image.Image:
        """Decode a base64-encoded image payload into an RGB PIL image.

        Args:
            b: Base64-encoded image bytes.

        Returns:
            The decoded image converted to RGB.

        Raises:
            Exception: Propagates decoding/parsing errors after logging them.
        """
        try:
            print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
            return Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB")
        except Exception as e:
            print(f"خطأ في فك الترميز: {e}")
            raise

    def _label_for(self, index: int) -> str:
        """Return the class label for *index*, tolerating int or str keys.

        transformers' PretrainedConfig normalizes ``id2label`` keys to int,
        but a JSON round-trip can leave them as str — accept either, and fall
        back to the raw index so a mapping gap never crashes inference.
        """
        label = self.labels.get(index, self.labels.get(str(index)))
        return label if label is not None else str(index)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Classify an input image and return per-class scores, best first.

        Args:
            data: Either a PIL image, or a dict whose ``"inputs"`` (or
                ``"image"``) entry is a base64-encoded image as str or bytes.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts sorted by
            descending score; on any failure a single
            ``{"label": "error", "score": 1.0}`` entry.
        """
        print(f"استدعاء __call__ مع نوع البيانات: {type(data)}")
        img = None
        try:
            if isinstance(data, Image.Image):
                img = data
            elif isinstance(data, dict):
                payload = data.get("inputs") or data.get("image")
                if isinstance(payload, str):
                    payload = payload.encode()
                if isinstance(payload, bytes):
                    img = self._decode_b64(payload)
            if img is None:
                print("لم يتم العثور على صورة صالحة")
                return [{"label": "error", "score": 1.0}]
            print("تحويل الصورة إلى مدخلات الموديل")
            inputs = self.processor(images=img, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits[0], dim=0)
            results = []
            for i, prob in enumerate(probs):
                # BUG FIX: the old code indexed self.labels[str(i)], but
                # id2label is int-keyed in memory, so every call raised
                # KeyError and returned the generic error result.
                results.append({
                    "label": self._label_for(i),
                    "score": round(prob.item(), 4),
                })
            results.sort(key=lambda x: x["score"], reverse=True)
            print(f"نتائج التصنيف: {results}")
            return results
        except Exception as e:
            print(f"حدث استثناء: {e}")
            return [{"label": "error", "score": 1.0}]
|