# ai-image-detector / handler.py
# yaya36095's picture
# Update handler.py
# 3172202 verified
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import json
import sys
def _print_startup_diagnostics():
    """Print interpreter, PyTorch and directory details to stdout for debugging."""
    print("Handler module loaded")
    print(f"Python version: {sys.version}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"Directory contents: {os.listdir('.')}")
    repository_dir = '/repository'
    # The repository listing is only printed when that directory exists.
    if os.path.exists(repository_dir):
        print(f"Repository directory contents: {os.listdir(repository_dir)}")


# Runs once at import time so the environment details appear in the logs.
_print_startup_diagnostics()
# For debugging
class ViTForImageClassification:
    """Guard stub that rejects any attempt to load this handler's model as a ViT."""

    @staticmethod
    def from_pretrained(model_dir):
        """Log the erroneous call and refuse it.

        This is a deliberately fake method whose only job is to catch
        erroneous imports/uses of ``ViTForImageClassification``.

        Args:
            model_dir: The directory the caller tried to load from; only
                echoed in the log line.

        Raises:
            ValueError: Always.
        """
        print(f"ERROR: ViTForImageClassification.from_pretrained was called with {model_dir}")
        reason = "ViTForImageClassification is not the correct model for this application"
        raise ValueError(reason)
class EndpointHandler:
    """Endpoint handler that scores an image as real vs. AI-generated.

    The handler builds an EfficientNetV2-S network with a custom three-layer
    classifier head, loads trained weights from disk, and on each call returns
    one ``{"label", "score"}`` entry per class. Scores are softmax
    probabilities expressed as percentages (0-100).
    """

    def __init__(self, model_dir):
        """Initialize the model for AI image detection.

        Args:
            model_dir: Directory expected to contain the trained weights.

        If the trained weights cannot be loaded, the error is logged and a
        fallback network is used instead (see ``_build_fallback_model``).
        """
        print(f"Initializing EndpointHandler with model_dir: {model_dir}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Defined before the model so preprocessing exists even if loading fails.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Per-channel mean / std (the usual ImageNet statistics).
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Index i of the model output corresponds to self.classes[i].
        self.classes = ["Real Image", "AI-Generated Image"]

        try:
            self.model = self._load_model(model_dir)
            print("Model loaded successfully")
        except Exception as e:
            # Deliberate best-effort: keep the endpoint alive with a fallback.
            print(f"Error loading model: {e}")
            print("Creating a dummy model as fallback")
            self.model = self._build_fallback_model()

    def _build_fallback_model(self):
        """Build the stand-in network used when the trained weights are missing.

        The backbone uses ImageNet weights, but the final two-class layer is
        freshly initialised, so predictions from this model are NOT meaningful;
        it only keeps the endpoint responsive.

        Returns:
            The fallback module, on ``self.device`` and in eval mode.
        """
        # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1
        # is the checkpoint the legacy flag selected.
        model = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 2)
        # Inputs are moved to self.device in __call__, so the model must be
        # there too or inference fails with a device mismatch on CUDA hosts.
        model.to(self.device)
        model.eval()
        return model

    def _load_model(self, model_dir):
        """Build the network and load the trained weights from disk.

        Args:
            model_dir: Directory searched first for the weights file.

        Returns:
            The loaded module, on ``self.device`` and in eval mode.

        Raises:
            FileNotFoundError: If none of the candidate weight files exists.
        """
        print(f"Loading model from directory: {model_dir}")
        print(f"Directory contents: {os.listdir(model_dir)}")

        model = models.efficientnet_v2_s(weights=None)

        # Recreate the classifier exactly as in training so the state dict fits.
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier[1].in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2),
        )

        # Candidate locations, tried in order; the first existing file wins.
        possible_paths = [
            os.path.join(model_dir, "best_model_improved.pth"),
            os.path.join(model_dir, "pytorch_model.bin"),
            "best_model_improved.pth",
            "/repository/best_model_improved.pth",
        ]
        for model_path in possible_paths:
            print(f"Trying model path: {model_path}")
            if os.path.exists(model_path):
                print(f"Found model at: {model_path}")
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints from a trusted source.
                state_dict = torch.load(model_path, map_location=self.device)
                model.load_state_dict(state_dict)
                break
        else:
            # Loop finished without `break`: no candidate file exists.
            raise FileNotFoundError(
                f"Model file not found in any of these locations: {possible_paths}"
            )

        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _to_rgb_image(input_data):
        """Convert a supported payload into an RGB PIL image.

        Args:
            input_data: A base64 string (optionally a ``data:...;base64,`` URL),
                raw encoded image bytes, a file-like object with ``read``, or
                a PIL image.

        Returns:
            The decoded image in RGB mode, or ``None`` if the payload type is
            not supported.
        """
        if isinstance(input_data, str):
            print("Processing base64 string image")
            import base64
            from io import BytesIO

            # Drop a data-URL prefix ("data:image/png;base64,...") if present.
            if "," in input_data:
                input_data = input_data.split(",", 1)[1]
            image_bytes = base64.b64decode(input_data)
            return Image.open(BytesIO(image_bytes)).convert("RGB")

        if isinstance(input_data, (bytes, bytearray)):
            print("Processing raw bytes image")
            from io import BytesIO

            return Image.open(BytesIO(input_data)).convert("RGB")

        if hasattr(input_data, "read"):
            print("Processing file-like object image")
            return Image.open(input_data).convert("RGB")

        if isinstance(input_data, Image.Image):
            print("Processing PIL Image")
            # The normalisation expects exactly three channels, so RGBA,
            # grayscale and palette images must be converted as well.
            return input_data.convert("RGB")

        return None

    def __call__(self, data):
        """Run prediction on the input data.

        Args:
            data: Either a dict carrying the payload under ``"inputs"`` (API
                format) or the payload itself.

        Returns:
            On success, a list with one ``{"label": str, "score": float}``
            dict per class, in ``self.classes`` order, where ``score`` is the
            softmax probability times 100. On failure, a dict with an
            ``"error"`` message (plus ``"traceback"`` for unexpected
            exceptions).
        """
        try:
            print(f"Received prediction request with data type: {type(data)}")

            if isinstance(data, dict) and "inputs" in data:
                # API format
                input_data = data["inputs"]
                print(f"Extracted input data from API format, type: {type(input_data)}")
            else:
                # Direct image
                input_data = data

            image = self._to_rgb_image(input_data)
            if image is None:
                print(f"Unsupported input type: {type(input_data)}")
                return {"error": f"Unsupported input type: {type(input_data)}"}

            # Add a leading batch dimension and move to the model's device.
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

            # Return the format required by the API
            # (Array<label: string, score: number>) rather than a plain dict.
            return [
                {"label": label, "score": float(probability.item() * 100)}
                for label, probability in zip(self.classes, probabilities)
            ]
        except Exception as e:
            import traceback

            print(f"Error during prediction: {e}")
            traceback.print_exc()
            return {"error": str(e), "traceback": traceback.format_exc()}