Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoModelForImageClassification, AutoFeatureExtractor | |
from PIL import Image | |
import gradio as gr | |
# Class labels, indexed by the classifier's output position. The order must
# match the label order used when the weights in distilled_vgg16.pth were
# trained.
class_names = ["Normal", "Cancer", "Malignant"]

# Preprocessor loaded from the same checkpoint the model architecture uses.
extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Build the ViT architecture with a classification head sized to class_names.
# The label maps are stored on model.config so the head's outputs are
# self-describing; the head's actual weights come from the state dict below.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(class_names),
    id2label=dict(enumerate(class_names)),
    label2id={name: idx for idx, name in enumerate(class_names)},
)

# Load the fine-tuned weights. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint file cannot execute
# arbitrary code when it is loaded.
# NOTE(review): the file name mentions VGG16 but the weights are loaded into a
# ViT; load_state_dict (strict by default) raises if the keys do not match —
# confirm which checkpoint this app is meant to ship.
state_dict = torch.load("distilled_vgg16.pth", map_location="cpu", weights_only=True)

# A head saved as an nn.Sequential stores its first layer under
# "classifier.0.*", whereas this model expects "classifier.weight" /
# "classifier.bias". Rename each key independently so a checkpoint that has
# only one of them does not fail with a bare KeyError here.
for param_name in ("weight", "bias"):
    old_key = f"classifier.0.{param_name}"
    if old_key in state_dict:
        state_dict[f"classifier.{param_name}"] = state_dict.pop(old_key)

# Load weights into the model and switch to evaluation mode for inference.
model.load_state_dict(state_dict)
model.eval()
# Define prediction function
def predict(image: "Image.Image | None") -> str:
    """Classify an uploaded image and return a human-readable result.

    Args:
        image: PIL image supplied by the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading.

    Returns:
        A message naming the predicted class from ``class_names``, or a
        prompt to upload an image when ``image`` is ``None``.
    """
    # Gradio passes None when nothing was uploaded; the extractor would raise.
    if image is None:
        return "Please upload a CT scan image."

    # Scans are often grayscale ("L") or RGBA PNGs; the ViT model presumably
    # expects 3-channel input (default ViT config — confirm), so normalize the
    # mode before preprocessing.
    image = image.convert("RGB")

    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax is monotonic, so the argmax of the logits is the same class the
    # softmax probabilities would pick.
    top_idx = int(logits[0].argmax())
    predicted_class = class_names[top_idx]
    return f"Prediction: {predicted_class}\nPlease consult a doctor for further diagnosis."
# Gradio UI: a single PIL image upload in, a single text box out, both wired
# to predict().
image_input = gr.Image(type="pil", label="Upload Lung CT Scan")
result_output = gr.Textbox(label="Prediction")

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=result_output,
    title="Lung Cancer Classifier using ViT",
    description=(
        "Upload a CT scan image to classify it as Normal, Cancer, or Malignant. "
        "This model is based on a Vision Transformer (ViT)."
    ),
)

# Start the web app only when this file is run directly, not when imported.
if __name__ == "__main__":
    interface.launch()