File size: 1,767 Bytes
2b85da0
1ac09fc
2b85da0
794c694
2b85da0
057e198
18e591c
2b85da0
057e198
1ac09fc
2b85da0
057e198
1ac09fc
 
 
9a03a67
057e198
9a03a67
 
057e198
9a03a67
 
 
057e198
9a03a67
794c694
 
057e198
 
 
2b85da0
1ac09fc
 
9a03a67
057e198
 
72c8baa
057e198
72c8baa
 
057e198
 
 
 
4f7be0c
 
 
 
057e198
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import gradio as gr

# Output labels. predict() looks up the argmax index of the logits in this
# list, so the order must match the label indices used during training.
class_names = ["Normal", "Cancer", "Malignant"]

# Preprocessing and network skeleton are both taken from this hub checkpoint.
_BASE_MODEL_ID = "google/vit-base-patch16-224-in21k"

extractor = AutoFeatureExtractor.from_pretrained(_BASE_MODEL_ID)

# ViT backbone with a classification head sized for 3 labels; its weights are
# replaced by the local checkpoint loaded below.
model = AutoModelForImageClassification.from_pretrained(
    _BASE_MODEL_ID,
    num_labels=3,
)

# NOTE(review): the file name mentions "vgg16", yet the weights are loaded
# strictly into a ViT — confirm the checkpoint was saved from this architecture.
# NOTE(review): torch.load deserialises via pickle (how much depends on the
# torch version / weights_only); only ship checkpoints you trust.
state_dict = torch.load("distilled_vgg16.pth", map_location="cpu")

# Head parameters saved as "classifier.0.*" (presumably a head wrapped in an
# nn.Sequential) are renamed to the flat "classifier.*" keys of the ViT head.
_HEAD_KEY_RENAMES = (
    ("classifier.0.weight", "classifier.weight"),
    ("classifier.0.bias", "classifier.bias"),
)
if "classifier.0.weight" in state_dict:
    state_dict.update(
        {new_key: state_dict.pop(old_key) for old_key, new_key in _HEAD_KEY_RENAMES}
    )

# Strict by default: any missing or unexpected key raises here.
model.load_state_dict(state_dict)
# Inference mode (affects layers such as dropout).
model.eval()

# Define prediction function
def predict(image: Image.Image):
    """Classify an uploaded CT image and return a human-readable verdict.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``"Prediction: <label>"`` followed by a reminder to consult a doctor,
        where ``<label>`` is the entry of ``class_names`` at the arg-max of
        the model's class probabilities. If ``image`` is ``None``, returns a
        prompt asking the user to upload an image instead.
    """
    if image is None:
        # An empty image component yields None. Without this guard the
        # extractor call below would error out and the UI would show a
        # generic failure.
        return "Please upload a CT scan image."
    # The ViT preprocessing expects 3-channel input. gr.Image normally
    # delivers RGB already, but converting here also covers RGBA/grayscale
    # PIL images when predict() is called directly.
    image = image.convert("RGB")
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        # Batch of one: take row 0 of the (batch, num_labels) probabilities.
        probs = torch.softmax(outputs.logits, dim=-1)[0]
        top_idx = torch.argmax(probs).item()
    predicted_class = class_names[top_idx]
    return f"Prediction: {predicted_class}\nPlease consult a doctor for further diagnosis."

# Set up Gradio interface
# Gradio front end: a single PIL image in, a single text verdict out.
interface = gr.Interface(
    fn=predict,
    title="Lung Cancer Classifier using ViT",
    description=(
        "Upload a CT scan image to classify it as Normal, Cancer, or Malignant. "
        "This model is based on a Vision Transformer (ViT)."
    ),
    inputs=gr.Image(label="Upload Lung CT Scan", type="pil"),
    outputs=gr.Textbox(label="Prediction"),
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()