File size: 3,975 Bytes
dd32056
 
 
54240fb
00fa5d2
 
963149c
00fa5d2
8ea56b1
dd32056
bc35f37
 
b50aa1b
00fa5d2
bc35f37
b50aa1b
bc35f37
963149c
 
b50aa1b
963149c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed7f1c1
b50aa1b
963149c
b50aa1b
 
 
 
 
 
 
 
 
 
 
bc35f37
b50aa1b
bc35f37
00fa5d2
ba3f2d0
bc35f37
 
b50aa1b
bc35f37
b50aa1b
bc35f37
ba3f2d0
54240fb
b50aa1b
ba3f2d0
b50aa1b
bc35f37
 
 
 
 
b50aa1b
 
 
bc35f37
b50aa1b
bc35f37
b50aa1b
bc35f37
b50aa1b
 
963149c
b50aa1b
bc35f37
963149c
b50aa1b
 
bc35f37
b50aa1b
 
bc35f37
 
b50aa1b
bc35f37
b50aa1b
bc35f37
b50aa1b
bc35f37
b50aa1b
bc35f37
b50aa1b
9a050d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import base64
import io
import os
from typing import Dict, Any, List
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

class EndpointHandler:
    def __init__(self, model_dir: str) -> None:
        print(f"بدء تهيئة النموذج من المسار: {model_dir}")
        print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"استخدام الجهاز: {self.device}")

        try:
            print("تحميل معالج الصور ViT")
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

            print("تحميل نموذج ViT")
            self.model = ViTForImageClassification.from_pretrained(
                "google/vit-base-patch16-224",
                num_labels=5,
                id2label={
                    0: "stable_diffusion",
                    1: "midjourney",
                    2: "dalle",
                    3: "real",
                    4: "other_ai"
                },
                label2id={
                    "stable_diffusion": 0,
                    "midjourney": 1,
                    "dalle": 2,
                    "real": 3,
                    "other_ai": 4
                },
                ignore_mismatched_sizes=True
            )

            custom_weights = os.path.join(model_dir, "pytorch_model.bin")
            if os.path.exists(custom_weights):
                print(f"تحميل الأوزان من: {custom_weights}")
                state_dict = torch.load(custom_weights, map_location="cpu")
                self.model.load_state_dict(state_dict, strict=False)
                print("تم تحميل الأوزان بنجاح")

            self.model.to(self.device).eval()
            self.labels = self.model.config.id2label

        except Exception as e:
            print(f"خطأ أثناء تهيئة النموذج: {e}")
            raise

    def _decode_b64(self, b: bytes) -> Image.Image:
        try:
            print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
            return Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB")
        except Exception as e:
            print(f"خطأ في فك الترميز: {e}")
            raise

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        print(f"استدعاء __call__ مع نوع البيانات: {type(data)}")

        img = None
        try:
            if isinstance(data, Image.Image):
                img = data
            elif isinstance(data, dict):
                payload = data.get("inputs") or data.get("image")
                if isinstance(payload, str):
                    payload = payload.encode()
                if isinstance(payload, bytes):
                    img = self._decode_b64(payload)

            if img is None:
                print("لم يتم العثور على صورة صالحة")
                return [{"label": "error", "score": 1.0}]

            print("تحويل الصورة إلى مدخلات الموديل")
            inputs = self.processor(images=img, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits[0], dim=0)

            results = []
            for i, prob in enumerate(probs):
                label = self.labels[str(i)]
                results.append({
                    "label": label,
                    "score": round(prob.item(), 4)
                })

            results.sort(key=lambda x: x["score"], reverse=True)
            print(f"نتائج التصنيف: {results}")
            return results

        except Exception as e:
            print(f"حدث استثناء: {e}")
            return [{"label": "error", "score": 1.0}]