Spaces:
Sleeping
Sleeping
File size: 1,767 Bytes
2b85da0 1ac09fc 2b85da0 794c694 2b85da0 057e198 18e591c 2b85da0 057e198 1ac09fc 2b85da0 057e198 1ac09fc 9a03a67 057e198 9a03a67 057e198 9a03a67 057e198 9a03a67 794c694 057e198 2b85da0 1ac09fc 9a03a67 057e198 72c8baa 057e198 72c8baa 057e198 4f7be0c 057e198 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import gradio as gr
# Class labels, in the index order produced by the classifier head.
# NOTE(review): "Cancer" and "Malignant" read as overlapping categories —
# confirm the names and their order against the training dataset.
class_names = ["Normal", "Cancer", "Malignant"]

# Hub checkpoint that supplies both the preprocessing config and the ViT
# architecture; kept in one place so the two can never point at different
# checkpoints.
_BASE_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Preprocessing (resize / normalisation) matching the base checkpoint.
extractor = AutoFeatureExtractor.from_pretrained(_BASE_CHECKPOINT)

# Model architecture. The head size is derived from class_names so the number
# of logits always matches the labels indexed in predict().
model = AutoModelForImageClassification.from_pretrained(
    _BASE_CHECKPOINT,
    num_labels=len(class_names),
)

# Fine-tuned weights, loaded on CPU regardless of where they were saved.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only ship a trusted checkpoint, and consider weights_only=True if the
# installed torch version supports it.
# NOTE(review): the file name suggests a VGG16 checkpoint while the target
# model is a ViT — confirm the file really holds ViT-compatible weights,
# otherwise load_state_dict below will raise on mismatched keys.
state_dict = torch.load("distilled_vgg16.pth", map_location="cpu")

# Checkpoints saved with a Sequential head store it as "classifier.0.*";
# rename those entries to the flat "classifier.*" keys used by this model.
if "classifier.0.weight" in state_dict:
    state_dict["classifier.weight"] = state_dict.pop("classifier.0.weight")
    state_dict["classifier.bias"] = state_dict.pop("classifier.0.bias")

# Strict load: any missing or unexpected key raises instead of being ignored.
model.load_state_dict(state_dict)

# Inference mode (disables dropout and similar training-only behaviour).
model.eval()
# Prediction callback used by the Gradio interface.
# The parameter annotation is quoted so the union is not evaluated when the
# function is defined (keeps the module importable on Python < 3.10).
def predict(image: "Image.Image | None") -> str:
    """Classify an uploaded CT scan and return a human-readable verdict.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the form is submitted without an upload.

    Returns:
        A two-line message naming the predicted class followed by a medical
        disclaimer, or a short prompt asking for an upload when no image was
        provided.
    """
    # Gradio hands the callback None when nothing was uploaded; answer with a
    # clear message instead of failing inside the feature extractor.
    if image is None:
        return "No image provided. Please upload a CT scan image."

    # CT scans are frequently stored as grayscale (or with an alpha channel);
    # converting guarantees the three-channel input the ViT model expects.
    inputs = extractor(images=image.convert("RGB"), return_tensors="pt")

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax is monotonic, so the argmax of the raw logits selects the same
    # class the softmax probabilities would.
    top_idx = int(logits[0].argmax())
    predicted_class = class_names[top_idx]
    return f"Prediction: {predicted_class}\nPlease consult a doctor for further diagnosis."
# Gradio UI: a single image upload feeds predict(), whose text result is shown
# in a text box.
_scan_input = gr.Image(type="pil", label="Upload Lung CT Scan")
_result_box = gr.Textbox(label="Prediction")

interface = gr.Interface(
    fn=predict,
    inputs=_scan_input,
    outputs=_result_box,
    title="Lung Cancer Classifier using ViT",
    description="Upload a CT scan image to classify it as Normal, Cancer, or Malignant. This model is based on a Vision Transformer (ViT).",
)

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()
|