File size: 6,969 Bytes
37e5914 3172202 37e5914 eecc81c 3172202 37e5914 3172202 37e5914 3172202 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import json
import sys
# Print debug information once, at import time, so deployment logs show the
# interpreter, the PyTorch build and which files are visible to the handler.
print("Handler module loaded")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
try:
    print(f"Directory contents: {os.listdir('.')}")
    # isdir (not exists): a plain file at this path cannot be listed.
    if os.path.isdir('/repository'):
        print(f"Repository directory contents: {os.listdir('/repository')}")
except OSError as exc:
    # These listings are diagnostics only; failing to produce them must not
    # prevent the handler module from being imported.
    print(f"Could not list directory contents: {exc}")
# For debugging: a stub that fails loudly if something tries to load a ViT model
class ViTForImageClassification:
    """Debugging trap: a stand-in that rejects any attempt to load a ViT model.

    This application serves an EfficientNet-based classifier, so a call to
    ``ViTForImageClassification.from_pretrained`` indicates a wrong import or
    a misconfigured loader. The stub logs the call and raises.
    """

    @staticmethod
    def from_pretrained(model_dir, *args, **kwargs):
        """Log the erroneous call and raise.

        Args:
            model_dir: Path the caller tried to load a model from.
            *args: Ignored. Accepted so that calls carrying extra positional
                arguments still reach the explanatory error below instead of
                failing with an unrelated ``TypeError``.
            **kwargs: Ignored, for the same reason as ``*args``.

        Raises:
            ValueError: Always.
        """
        # This is a fake method to catch erroneous imports
        print(f"ERROR: ViTForImageClassification.from_pretrained was called with {model_dir}")
        raise ValueError("ViTForImageClassification is not the correct model for this application")
class EndpointHandler:
    """Endpoint handler that scores an image as real versus AI-generated.

    The network is an EfficientNetV2-S whose classifier head is replaced by a
    small MLP with two outputs. Instances are callable: ``handler(data)``
    returns one ``{"label", "score"}`` entry per class.
    """

    def __init__(self, model_dir):
        """
        Initialize the model for AI image detection.

        Args:
            model_dir: Directory that is searched first for the trained
                weights (see ``_load_model`` for the full search order).
        """
        print(f"Initializing EndpointHandler with model_dir: {model_dir}")
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # Define transforms first (in case model loading fails)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Class names, in the order of the model's two outputs
        self.classes = ["Real Image", "AI-Generated Image"]
        # Load model
        try:
            self.model = self._load_model(model_dir)
            print("Model loaded successfully")
        except Exception as e:
            # Deliberate best-effort: keep the endpoint alive even when the
            # trained weights cannot be loaded. The replacement head below is
            # freshly initialised, so fallback predictions are not meaningful.
            print(f"Error loading model: {e}")
            print("Creating a dummy model as fallback")
            # `weights=` replaces the deprecated `pretrained=True`.
            self.model = models.efficientnet_v2_s(
                weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
            )
            self.model.classifier[-1] = nn.Linear(
                self.model.classifier[-1].in_features, 2
            )
            # Inputs are moved to self.device in __call__, so the fallback
            # model has to live on the same device.
            self.model.to(self.device)
            self.model.eval()

    def _load_model(self, model_dir):
        """Build the network and load the trained weights into it.

        The first existing file among ``best_model_improved.pth`` and
        ``pytorch_model.bin`` inside ``model_dir``, then
        ``best_model_improved.pth`` in the working directory, then
        ``/repository/best_model_improved.pth`` is used.

        Args:
            model_dir: Directory searched first for the weights file.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.

        Raises:
            FileNotFoundError: If none of the candidate files exists.
        """
        print(f"Loading model from directory: {model_dir}")
        print(f"Directory contents: {os.listdir(model_dir)}")
        # Create model architecture
        model = models.efficientnet_v2_s(weights=None)
        # Recreate classifier exactly as in training
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier[1].in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2),
        )
        # Try to find model file in multiple possible locations
        possible_paths = [
            os.path.join(model_dir, "best_model_improved.pth"),
            os.path.join(model_dir, "pytorch_model.bin"),
            "best_model_improved.pth",
            "/repository/best_model_improved.pth",
        ]
        for model_path in possible_paths:
            print(f"Trying model path: {model_path}")
            if os.path.exists(model_path):
                print(f"Found model at: {model_path}")
                # NOTE: torch.load unpickles the file; only point this at
                # checkpoints from a trusted source.
                model.load_state_dict(torch.load(model_path, map_location=self.device))
                break
        else:
            # The loop finished without a break: no candidate file exists.
            raise FileNotFoundError(f"Model file not found in any of these locations: {possible_paths}")
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _load_image(input_data):
        """Decode a supported payload into an RGB PIL image.

        Args:
            input_data: A base64 string (optionally prefixed by a data-URL
                header ending in a comma), raw image bytes, a file-like
                object with ``read``, or a PIL image.

        Returns:
            The image converted to RGB, or ``None`` if the payload type is
            not supported.
        """
        from io import BytesIO

        if isinstance(input_data, str):  # Base64 string
            print("Processing base64 string image")
            import base64
            # Drop a "data:image/...;base64," style prefix if present.
            if ',' in input_data:
                input_data = input_data.split(",", 1)[1]
            image_bytes = base64.b64decode(input_data)
            return Image.open(BytesIO(image_bytes)).convert("RGB")
        if isinstance(input_data, (bytes, bytearray)):  # Raw encoded image
            print("Processing raw bytes image")
            return Image.open(BytesIO(input_data)).convert("RGB")
        if hasattr(input_data, "read"):  # File-like object
            print("Processing file-like object image")
            return Image.open(input_data).convert("RGB")
        if isinstance(input_data, Image.Image):
            print("Processing PIL Image")
            # The normalisation step is defined for three channels, so
            # RGBA / grayscale / palette images must be converted as well.
            return input_data.convert("RGB")
        return None

    def _format_scores(self, probabilities):
        """Pair each class name with its probability expressed in percent.

        Args:
            probabilities: 1-D tensor holding one probability per entry of
                ``self.classes``.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts with scores on
            a 0-100 scale.
        """
        return [
            {"label": label, "score": float(prob * 100)}
            for label, prob in zip(self.classes, probabilities.tolist())
        ]

    def __call__(self, data):
        """
        Run prediction on the input data.

        Args:
            data: Either a dict carrying the payload under ``"inputs"`` or
                the payload itself (see ``_load_image`` for accepted types).

        Returns:
            On success, a list of ``{"label", "score"}`` dicts, one per
            class, with scores in percent. For an unsupported payload type,
            ``{"error": ...}``. For any other failure,
            ``{"error": ..., "traceback": ...}``.
        """
        try:
            print(f"Received prediction request with data type: {type(data)}")
            # Parse request data
            if isinstance(data, dict) and "inputs" in data:
                # API format
                input_data = data["inputs"]
                print(f"Extracted input data from API format, type: {type(input_data)}")
            else:
                # Direct image
                input_data = data
            # Process image
            image = self._load_image(input_data)
            if image is None:
                print(f"Unsupported input type: {type(input_data)}")
                return {"error": f"Unsupported input type: {type(input_data)}"}
            # Preprocess image
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            # Make prediction
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
            # Return the format the API requires
            # (Array<label: string, score: number>) rather than the earlier
            # dictionary format.
            return self._format_scores(probabilities)
        except Exception as e:
            import traceback
            print(f"Error during prediction: {e}")
            traceback.print_exc()
            return {"error": str(e), "traceback": traceback.format_exc()}
|