import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import json
import sys

print("Handler module loaded") |
|
print(f"Python version: {sys.version}") |
|
print(f"PyTorch version: {torch.__version__}") |
|
print(f"Directory contents: {os.listdir('.')}") |
|
if os.path.exists('/repository'): |
|
print(f"Repository directory contents: {os.listdir('/repository')}") |
|
|
|
|
|
class ViTForImageClassification:
    """Stub that fails loudly if anything tries to load this repository as a ViT model."""

    @staticmethod
    def from_pretrained(model_dir):
        print(f"ERROR: ViTForImageClassification.from_pretrained was called with {model_dir}")
        raise ValueError("ViTForImageClassification is not the correct model for this application")


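# Request/response contract (see __call__ below): the handler accepts
# {"inputs": <base64-encoded image string>}, a bare base64 string, a file-like
# object, or a PIL.Image, and returns a list of {"label", "score"} dicts where
# the scores are softmax probabilities expressed as percentages.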
class EndpointHandler:
    def __init__(self, model_dir):
        """
        Initialize the model for AI image detection.
        """
        print(f"Initializing EndpointHandler with model_dir: {model_dir}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Standard ImageNet preprocessing: 224x224 input, mean/std normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.classes = ["Real Image", "AI-Generated Image"]

        try:
            self.model = self._load_model(model_dir)
            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading model: {e}")

            # Fall back to an ImageNet-pretrained backbone with a fresh 2-class
            # head so the endpoint can still respond (the head is randomly
            # initialized, so its predictions are not meaningful).
            print("Creating a dummy model as fallback")
            self.model = models.efficientnet_v2_s(
                weights=models.EfficientNet_V2_S_Weights.DEFAULT
            )
            self.model.classifier[-1] = nn.Linear(
                self.model.classifier[-1].in_features, 2
            )
            self.model.to(self.device)
            self.model.eval()

    def _load_model(self, model_dir):
        print(f"Loading model from directory: {model_dir}")
        print(f"Directory contents: {os.listdir(model_dir)}")

        model = models.efficientnet_v2_s(weights=None)

        # Replace the default classifier with the head used at training time;
        # this structure must match the checkpoint exactly or load_state_dict
        # will fail.
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier[1].in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2)
        )

        # Look for the fine-tuned weights in the locations the toolkit may use.
        model_found = False
        possible_paths = [
            os.path.join(model_dir, "best_model_improved.pth"),
            os.path.join(model_dir, "pytorch_model.bin"),
            "best_model_improved.pth",
            "/repository/best_model_improved.pth"
        ]

        for model_path in possible_paths:
            print(f"Trying model path: {model_path}")
            if os.path.exists(model_path):
                print(f"Found model at: {model_path}")
                model.load_state_dict(torch.load(model_path, map_location=self.device))
                model_found = True
                break

        if not model_found:
            # Last resort: copy a checkpoint from the working directory into
            # model_dir, then load it from there.
            if os.path.exists('best_model_improved.pth') and not os.path.exists(os.path.join(model_dir, 'best_model_improved.pth')):
                import shutil
                print(f"Copying model file to {model_dir}")
                shutil.copy('best_model_improved.pth', os.path.join(model_dir, 'best_model_improved.pth'))
                model.load_state_dict(torch.load(os.path.join(model_dir, 'best_model_improved.pth'), map_location=self.device))
                model_found = True

        if not model_found:
            raise FileNotFoundError(f"Model file not found in any of these locations: {possible_paths}")

        model.to(self.device)
        model.eval()
        return model

    def __call__(self, data):
        """
        Run prediction on the input data.
        """
        try:
            print(f"Received prediction request with data type: {type(data)}")

            # Requests from the inference API arrive as {"inputs": ...};
            # anything else is treated as the image payload itself.
            if isinstance(data, dict) and "inputs" in data:
                input_data = data["inputs"]
                print(f"Extracted input data from API format, type: {type(input_data)}")
            else:
                input_data = data

            if isinstance(input_data, str):
                print("Processing base64 string image")
                import base64
                from io import BytesIO

                # Strip a data-URL prefix such as "data:image/png;base64," if present.
                if ',' in input_data:
                    input_data = input_data.split(",", 1)[1]
                image_bytes = base64.b64decode(input_data)
                image = Image.open(BytesIO(image_bytes)).convert("RGB")
            elif hasattr(input_data, "read"):
                print("Processing file-like object image")
                image = Image.open(input_data).convert("RGB")
            elif isinstance(input_data, Image.Image):
                print("Processing PIL Image")
                image = input_data.convert("RGB")
            else:
                print(f"Unsupported input type: {type(input_data)}")
                return {"error": f"Unsupported input type: {type(input_data)}"}

            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
                prediction = torch.argmax(probabilities).item()

            # Convert softmax probabilities to percentages for the response.
            real_prob = probabilities[0].item() * 100
            ai_prob = probabilities[1].item() * 100

            return [
                {
                    "label": self.classes[0],
                    "score": float(real_prob)
                },
                {
                    "label": self.classes[1],
                    "score": float(ai_prob)
                }
            ]

        except Exception as e:
            import traceback
            print(f"Error during prediction: {e}")
            traceback.print_exc()
            return {"error": str(e), "traceback": traceback.format_exc()}
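

# ---------------------------------------------------------------------------
# Local smoke test (illustrative sketch only; the hosted inference toolkit
# imports EndpointHandler and never runs this block). It assumes a checkpoint
# is available in the current directory and that a test image exists at the
# hypothetical path "sample.jpg".
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import base64

    handler = EndpointHandler(model_dir=".")

    # Encode the image the same way an HTTP client would before sending it
    # as the "inputs" field of a JSON request.
    with open("sample.jpg", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

    print(handler(payload))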