import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import gradio as gr

# Actual class names used in your training dataset
class_names = ["Normal", "Cancer", "Malignant"]

# Load feature extractor
extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Load model architecture
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=3
)

# Load state dict from file and rename keys if necessary
state_dict = torch.load("distilled_vgg16.pth", map_location="cpu")
if "classifier.0.weight" in state_dict:
    # Rename keys to match ViT classifier layer format
    state_dict["classifier.weight"] = state_dict.pop("classifier.0.weight")
    state_dict["classifier.bias"] = state_dict.pop("classifier.0.bias")

# Load weights into model
model.load_state_dict(state_dict)
model.eval()

# Define prediction function
def predict(image: Image.Image):
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)[0]
    top_idx = torch.argmax(probs).item()
    predicted_class = class_names[top_idx]
    return f"Prediction: {predicted_class}\nPlease consult a doctor for further diagnosis."

# Set up Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Lung CT Scan"),
    outputs=gr.Textbox(label="Prediction"),
    title="Lung Cancer Classifier using ViT",
    description="Upload a CT scan image to classify it as Normal, Cancer, or Malignant. This model is based on a Vision Transformer (ViT)."
)

if __name__ == "__main__":
    interface.launch()
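
# --- Optional sanity check (sketch, not part of the original app) ---
# The key-renaming step above assumes the checkpoint "distilled_vgg16.pth" only differs
# from the ViT head by the "classifier.0.*" naming. If load_state_dict() still raises
# missing/unexpected-key errors, a helper like the one below (hypothetical, uncomment to
# use) can show exactly which keys disagree with the ViT architecture.
#
# def inspect_checkpoint(path: str = "distilled_vgg16.pth") -> None:
#     ckpt = torch.load(path, map_location="cpu")
#     model_keys = set(model.state_dict().keys())
#     ckpt_keys = set(ckpt.keys())
#     print("Missing from checkpoint:", sorted(model_keys - ckpt_keys))
#     print("Unexpected in checkpoint:", sorted(ckpt_keys - model_keys))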