File size: 6,969 Bytes
37e5914
3172202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37e5914
eecc81c
 
3172202
 
 
 
 
 
37e5914
3172202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37e5914
3172202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import json
import sys

# Startup diagnostics, printed once when the module is imported: interpreter
# and PyTorch versions, followed by a listing of the working directory and,
# if that path exists, of /repository.
print("Handler module loaded")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
_cwd_entries = os.listdir('.')
print(f"Directory contents: {_cwd_entries}")
_repo_dir = '/repository'
if os.path.exists(_repo_dir):
    print(f"Repository directory contents: {os.listdir(_repo_dir)}")

# For debugging
class ViTForImageClassification:
    """Tripwire stand-in used to catch erroneous use of a ViT model class.

    Its only method logs the call and then fails; this application does not
    use a ViT model.
    """

    @staticmethod
    def from_pretrained(model_dir):
        """Log the unexpected call and fail.

        Args:
            model_dir: The value the caller passed in; only echoed in the log.

        Raises:
            ValueError: Always.
        """
        # Deliberately fake: reaching this method means something tried to
        # load the wrong model class.
        log_line = f"ERROR: ViTForImageClassification.from_pretrained was called with {model_dir}"
        print(log_line)
        raise ValueError("ViTForImageClassification is not the correct model for this application")

class EndpointHandler:
    """Callable handler that scores an image as real vs. AI-generated.

    Construction selects a device, builds the preprocessing pipeline and loads
    an EfficientNetV2-S checkpoint. Calling the instance with an image returns
    a two-element list of ``{"label": ..., "score": ...}`` dicts, one per
    entry of ``self.classes``.
    """

    def __init__(self, model_dir):
        """
        Initialize the model for AI image detection

        Args:
            model_dir: Directory that is searched first for the checkpoint.

        If loading the checkpoint fails for any reason, the error is printed
        and a fallback model is built instead so the handler can still start.
        The fallback has a randomly initialised 2-class head.
        """
        print(f"Initializing EndpointHandler with model_dir: {model_dir}")

        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Define transforms first (in case model loading fails)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Class names, in the order of the model's output logits
        self.classes = ["Real Image", "AI-Generated Image"]

        # Load model
        try:
            self.model = self._load_model(model_dir)
            print("Model loaded successfully")
        except Exception as e:
            # Startup boundary: report the failure and keep the handler alive.
            print(f"Error loading model: {e}")
            print("Creating a dummy model as fallback")
            self.model = self._build_fallback_model()

    def _build_fallback_model(self):
        """Build the stand-in model used when the real checkpoint cannot load.

        Returns:
            An ImageNet-pretrained EfficientNetV2-S whose final classifier
            layer is replaced by a fresh linear layer with one output per
            class, placed on ``self.device`` and switched to eval mode.
        """
        # `weights=` is the non-deprecated spelling of `pretrained=True` and
        # matches the API already used in `_load_model`.
        model = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        model.classifier[-1] = nn.Linear(
            model.classifier[-1].in_features, len(self.classes)
        )
        # The input tensor is moved to self.device in __call__, so the model
        # must live there too or every request fails with a device mismatch.
        model.to(self.device)
        model.eval()
        return model

    def _load_model(self, model_dir):
        """Build the network and load its weights from the first checkpoint found.

        Args:
            model_dir: Directory probed first for the checkpoint file.

        Returns:
            The model on ``self.device`` in eval mode.

        Raises:
            FileNotFoundError: If none of the candidate paths exists.
        """
        print(f"Loading model from directory: {model_dir}")
        # Purely diagnostic: a missing directory must not abort the search,
        # because two of the candidate paths below do not depend on it.
        if os.path.isdir(model_dir):
            print(f"Directory contents: {os.listdir(model_dir)}")

        # Create model architecture
        model = models.efficientnet_v2_s(weights=None)

        # Recreate classifier exactly as in training
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier[1].in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2)
        )

        # Try to find model file in multiple possible locations
        possible_paths = [
            os.path.join(model_dir, "best_model_improved.pth"),
            os.path.join(model_dir, "pytorch_model.bin"),
            "best_model_improved.pth",
            "/repository/best_model_improved.pth"
        ]

        for model_path in possible_paths:
            print(f"Trying model path: {model_path}")
            if not os.path.exists(model_path):
                continue
            print(f"Found model at: {model_path}")
            # NOTE(security): torch.load unpickles the file; only point this
            # at checkpoints from a trusted source.
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
            break
        else:
            raise FileNotFoundError(f"Model file not found in any of these locations: {possible_paths}")

        model.to(self.device)
        model.eval()
        return model

    def __call__(self, data):
        """
        Run prediction on the input data

        Args:
            data: Either the image itself or a dict carrying it under the
                ``"inputs"`` key. The image may be a base64 string (with or
                without a ``data:...,`` prefix), raw encoded bytes, a
                file-like object with ``read``, or a PIL image.

        Returns:
            On success, a list with one ``{"label": str, "score": float}``
            dict per class, where ``score`` is the softmax probability
            multiplied by 100. On failure, a dict with an ``"error"`` message
            (plus ``"traceback"`` for unexpected exceptions).
        """
        import base64
        from io import BytesIO

        try:
            print(f"Received prediction request with data type: {type(data)}")

            # Parse request data
            if isinstance(data, dict) and "inputs" in data:
                # API format
                input_data = data["inputs"]
                print(f"Extracted input data from API format, type: {type(input_data)}")
            else:
                # Direct image
                input_data = data

            # Process image; every branch normalises to 3-channel RGB because
            # the Normalize transform is configured for exactly three channels.
            if isinstance(input_data, str):  # Base64 string
                print("Processing base64 string image")
                # Drop a "data:image/...;base64," style prefix if present.
                if ',' in input_data:
                    input_data = input_data.split(",", 1)[1]
                image_bytes = base64.b64decode(input_data)
                image = Image.open(BytesIO(image_bytes)).convert("RGB")
            elif isinstance(input_data, (bytes, bytearray)):  # Raw encoded image
                print("Processing raw bytes image")
                image = Image.open(BytesIO(input_data)).convert("RGB")
            elif hasattr(input_data, "read"):  # File-like object
                print("Processing file-like object image")
                image = Image.open(input_data).convert("RGB")
            elif isinstance(input_data, Image.Image):
                print("Processing PIL Image")
                image = input_data.convert("RGB")
            else:
                print(f"Unsupported input type: {type(input_data)}")
                return {"error": f"Unsupported input type: {type(input_data)}"}

            # Preprocess image
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Make prediction
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

            # Return the format the API requires
            # (Array<label: string, score: number>) rather than the earlier
            # dict format. Scores are expressed as percentages.
            return [
                {"label": label, "score": float(prob * 100)}
                for label, prob in zip(self.classes, probabilities.tolist())
            ]

        except Exception as e:
            # Request boundary: report the failure to the caller instead of
            # letting the exception escape the handler.
            import traceback
            print(f"Error during prediction: {e}")
            traceback.print_exc()
            return {"error": str(e), "traceback": traceback.format_exc()}