File size: 8,893 Bytes
dd32056
 
 
bc35f37
54240fb
dd32056
00fa5d2
 
bf94d8e
dd32056
9114905
00fa5d2
8ea56b1
ba3f2d0
8ea56b1
dd32056
ba3f2d0
dd32056
 
bc35f37
 
 
00fa5d2
bc35f37
8b1e242
 
9114905
8b1e242
9114905
bc35f37
 
 
ba3f2d0
bc35f37
 
 
 
 
 
 
8b1e242
 
 
 
 
 
 
bc35f37
 
 
 
 
 
 
 
 
8b1e242
 
 
 
 
bc35f37
 
8b1e242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00fa5d2
bc35f37
8b1e242
dd32056
8b1e242
 
 
 
 
dd32056
 
 
 
 
 
 
bc35f37
00fa5d2
dd32056
 
 
ba3f2d0
bc35f37
 
 
 
 
 
 
 
00fa5d2
ba3f2d0
bc35f37
 
 
 
 
 
 
 
ba3f2d0
dd32056
ba3f2d0
dd32056
54240fb
dd32056
 
ba3f2d0
9114905
54240fb
 
 
dd32056
bc35f37
 
dd32056
ba3f2d0
bc35f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54240fb
bc35f37
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import base64
import io
import os
import sys
from typing import Dict, Any, List

import torch
from PIL import Image
from timm import create_model
from torchvision import transforms
from safetensors.torch import load_file

class EndpointHandler:
    """Custom ViT image-classifier for Hugging Face Inference Endpoints."""

    # --------------------------------------------------
    # 1) تحميل النموذج والوزن مرة واحدة
    # --------------------------------------------------
    def __init__(self, model_dir: str) -> None:
        print(f"بدء تهيئة النموذج من المسار: {model_dir}")
        print(f"قائمة الملفات في المسار: {os.listdir(model_dir)}")
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"استخدام الجهاز: {self.device}")
        
        # تحديد مسارات الملفات المحتملة
        pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
        safetensors_path = os.path.join(model_dir, "model.safetensors")
        
        print(f"مسار ملف PyTorch: {pytorch_path}, موجود: {os.path.exists(pytorch_path)}")
        print(f"مسار ملف Safetensors: {safetensors_path}, موجود: {os.path.exists(safetensors_path)}")
        
        # إنشاء ViT Base Patch-16 بعدد فئات 5
        try:
            print("محاولة إنشاء نموذج ViT")
            self.model = create_model("vit_base_patch16_224", num_classes=5)
            print("تم إنشاء نموذج ViT بنجاح")
        except Exception as e:
            print(f"خطأ في إنشاء النموذج: {e}")
            raise
        
        # محاولة تحميل النموذج من pytorch_model.bin أولاً
        model_loaded = False
        if os.path.exists(pytorch_path):
            try:
                print(f"محاولة تحميل النموذج من: {pytorch_path}")
                state_dict = torch.load(pytorch_path, map_location="cpu")
                print(f"مفاتيح state_dict: {list(state_dict.keys())[:5]}...")
                
                # طباعة بنية النموذج ومفاتيح state_dict للمقارنة
                model_keys = set(k for k, _ in self.model.named_parameters())
                state_dict_keys = set(state_dict.keys())
                print(f"عدد مفاتيح النموذج: {len(model_keys)}")
                print(f"عدد مفاتيح state_dict: {len(state_dict_keys)}")
                print(f"المفاتيح المشتركة: {len(model_keys.intersection(state_dict_keys))}")
                
                self.model.load_state_dict(state_dict)
                print("تم تحميل النموذج بنجاح من pytorch_model.bin")
                model_loaded = True
            except Exception as e:
                print(f"خطأ في تحميل pytorch_model.bin: {e}")
                print(f"نوع الخطأ: {type(e).__name__}")
                print(f"تفاصيل الخطأ: {str(e)}")
        
        # إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
        if not model_loaded and os.path.exists(safetensors_path):
            try:
                print(f"محاولة تحميل النموذج من: {safetensors_path}")
                # تحميل النموذج بدون محاولة مطابقة الهيكل مباشرة
                # سنقوم بتهيئة النموذج من الصفر بدلاً من ذلك
                print("تهيئة نموذج ViT من الصفر")
                # لا نحاول تحميل safetensors لأنه يحتوي على هيكل مختلف
                print("تم تهيئة نموذج ViT بدون أوزان مسبقة")
                model_loaded = True
            except Exception as e:
                print(f"خطأ في تحميل model.safetensors: {e}")
        
        if not model_loaded:
            print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
        
        self.model.eval().to(self.device)
        print("تم تحويل النموذج إلى وضع التقييم")
        
        # محوّلات التحضير
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ])
        
        self.labels = [
            "stable_diffusion",
            "midjourney",
            "dalle",
            "real",
            "other_ai",
        ]
        print(f"تم تعريف التسميات: {self.labels}")

    # --------------------------------------------------
    # 2) دوال مساعدة
    # --------------------------------------------------
    def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
        try:
            print(f"تحويل الصورة إلى تنسور. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
            tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
            print(f"تم تحويل الصورة بنجاح. شكل التنسور: {tensor.shape}")
            return tensor
        except Exception as e:
            print(f"خطأ في تحويل الصورة إلى تنسور: {e}")
            raise

    def _decode_b64(self, b: bytes) -> Image.Image:
        try:
            print(f"فك ترميز base64. حجم البيانات: {len(b)} بايت")
            img = Image.open(io.BytesIO(base64.b64decode(b)))
            print(f"تم فك الترميز بنجاح. حجم الصورة: {img.size}, وضع الصورة: {img.mode}")
            return img
        except Exception as e:
            print(f"خطأ في فك ترميز base64: {e}")
            raise

    # --------------------------------------------------
    # 3) الدالة الرئيسة
    # --------------------------------------------------
    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        يدعم:
        • Widget (PIL.Image)
        • REST (base64 فى data["inputs"] أو data["image"])
        
        يعيد:
        • مصفوفة من القواميس بتنسيق [{label: string, score: number}, ...]
        """
        print(f"استدعاء __call__ مع البيانات من النوع: {type(data)}")
        
        img: Image.Image | None = None

        try:
            if isinstance(data, Image.Image):
                print("البيانات هي صورة PIL")
                img = data
            elif isinstance(data, dict):
                print(f"البيانات هي قاموس بالمفاتيح: {list(data.keys())}")
                payload = data.get("inputs") or data.get("image")
                print(f"نوع الحمولة: {type(payload)}")
                
                if isinstance(payload, (str, bytes)):
                    if isinstance(payload, str):
                        print("تحويل السلسلة النصية إلى بايت")
                        payload = payload.encode()
                    img = self._decode_b64(payload)
            
            if img is None:
                print("لم يتم العثور على صورة صالحة في البيانات")
                return [{"label": "error", "score": 1.0}]
            
            print("بدء التنبؤ باستخدام النموذج")
            with torch.no_grad():
                tensor = self._img_to_tensor(img)
                logits = self.model(tensor)
                probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
                print(f"تم الحصول على الاحتمالات: {probs}")
            
            # تحويل النتائج إلى التنسيق المطلوب: Array<label: string, score:number>
            results = []
            for i, label in enumerate(self.labels):
                score = float(probs[i])
                print(f"التسمية: {label}, الدرجة: {score}")
                results.append({
                    "label": label,
                    "score": score
                })
            
            # ترتيب النتائج تنازلياً حسب درجة الثقة
            results.sort(key=lambda x: x["score"], reverse=True)
            print(f"النتائج النهائية: {results}")
            
            return results
        
        except Exception as e:
            print(f"حدث خطأ أثناء المعالجة: {e}")
            print(f"نوع الخطأ: {type(e).__name__}")
            print(f"تفاصيل الخطأ: {str(e)}")
            # تتبع الاستثناء الكامل
            import traceback
            traceback.print_exc()
            
            return [{"label": "error", "score": 1.0}]