# Zuhyer999's picture
# Update app.py
# 057e198 verified
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import gradio as gr
# Label names, in classifier output order: logit index i maps to class_names[i].
class_names = ["Normal", "Cancer", "Malignant"]

# Preprocessor that matches the ViT backbone checkpoint.
extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# ViT backbone with a classification head sized to the label set, so the head
# cannot silently drift out of sync with class_names.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=len(class_names)
)

# Load the fine-tuned weights on CPU. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint file cannot run
# arbitrary code while being loaded.
# NOTE(review): the file name suggests a VGG16 checkpoint while the model built
# above is a ViT -- confirm the file really holds ViT weights. load_state_dict
# below is strict and raises on any missing or unexpected key.
state_dict = torch.load(
    "distilled_vgg16.pth", map_location="cpu", weights_only=True
)

# Rename head parameters stored as "classifier.0.*" to the "classifier.*" names
# the ViT classification model expects. Each key is checked on its own so a
# checkpoint missing one of them fails later with load_state_dict's descriptive
# error rather than a bare KeyError here.
for param in ("weight", "bias"):
    legacy_key = f"classifier.0.{param}"
    if legacy_key in state_dict:
        state_dict[f"classifier.{param}"] = state_dict.pop(legacy_key)

# Load weights into the model and switch it to inference mode.
model.load_state_dict(state_dict)
model.eval()
# Define prediction function
def predict(image: "Image.Image"):
    """Classify an uploaded image and return a human-readable verdict.

    Args:
        image: PIL image supplied by the Gradio image component. ``None`` is
            tolerated (the UI sends it when nothing was uploaded).

    Returns:
        A text message: either the predicted class name followed by a medical
        disclaimer, or a prompt to upload an image when ``image`` is ``None``.
    """
    # Submitting the form without an image passes None; answer with a prompt
    # instead of letting the extractor raise.
    if image is None:
        return "Please upload an image first."

    # Normalise to 3-channel RGB so grayscale/RGBA inputs from direct callers
    # are preprocessed the same way as RGB uploads.
    inputs = extractor(images=image.convert("RGB"), return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax is monotonic, so the arg-max of the raw logits selects the same
    # class as the arg-max of the probabilities.
    top_idx = int(outputs.logits[0].argmax())
    predicted_class = class_names[top_idx]
    return f"Prediction: {predicted_class}\nPlease consult a doctor for further diagnosis."
# Gradio UI: a single image upload wired through `predict` to a text box.
ct_scan_input = gr.Image(type="pil", label="Upload Lung CT Scan")
prediction_output = gr.Textbox(label="Prediction")

interface = gr.Interface(
    fn=predict,
    inputs=ct_scan_input,
    outputs=prediction_output,
    title="Lung Cancer Classifier using ViT",
    description=(
        "Upload a CT scan image to classify it as Normal, Cancer, or Malignant. "
        "This model is based on a Vision Transformer (ViT)."
    ),
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()