Zuhyer999 commited on
Commit
057e198
·
verified ·
1 Parent(s): 4f7be0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -3,46 +3,47 @@ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
  from PIL import Image
4
  import gradio as gr
5
 
6
- # Define class labels
7
  class_names = ["Normal", "Cancer", "Malignant"]
8
 
9
- # Load the ViT feature extractor
10
  extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
11
 
12
- # Load the pretrained ViT model
13
  model = AutoModelForImageClassification.from_pretrained(
14
  "google/vit-base-patch16-224-in21k", num_labels=3
15
  )
16
 
17
- # === Rename keys to match expected format ===
18
  state_dict = torch.load("distilled_vgg16.pth", map_location="cpu")
19
-
20
- # Fix classifier layer keys if needed
21
  if "classifier.0.weight" in state_dict:
 
22
  state_dict["classifier.weight"] = state_dict.pop("classifier.0.weight")
23
  state_dict["classifier.bias"] = state_dict.pop("classifier.0.bias")
24
 
25
- # Load weights
26
  model.load_state_dict(state_dict)
27
  model.eval()
28
 
29
- # Prediction function
30
- def predict(img: Image.Image):
31
- inputs = extractor(images=img, return_tensors="pt")
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
  probs = torch.softmax(outputs.logits, dim=1)[0]
35
  top_idx = torch.argmax(probs).item()
36
- return f"Prediction: {class_names[top_idx]}\nPlease consult a doctor for further diagnosis."
 
37
 
38
- # Gradio interface
39
  interface = gr.Interface(
40
  fn=predict,
41
- inputs=gr.Image(type="pil", label="Upload Lung CT Image"),
42
- outputs=gr.Textbox(label="Diagnosis"),
43
- title="ViT Lung Cancer Classifier",
44
- description="Upload a CT scan image. The model predicts one of: Normal, Cancer, or Malignant."
45
  )
46
 
47
  if __name__ == "__main__":
48
  interface.launch()
 
 
3
  from PIL import Image
4
  import gradio as gr
5
 
6
+ # Actual class names used in your training dataset
7
  class_names = ["Normal", "Cancer", "Malignant"]
8
 
9
+ # Load feature extractor
10
  extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
11
 
12
+ # Load model architecture
13
  model = AutoModelForImageClassification.from_pretrained(
14
  "google/vit-base-patch16-224-in21k", num_labels=3
15
  )
16
 
17
+ # Load state dict from file and rename keys if necessary
18
  state_dict = torch.load("distilled_vgg16.pth", map_location="cpu")
 
 
19
  if "classifier.0.weight" in state_dict:
20
+ # Rename keys to match ViT classifier layer format
21
  state_dict["classifier.weight"] = state_dict.pop("classifier.0.weight")
22
  state_dict["classifier.bias"] = state_dict.pop("classifier.0.bias")
23
 
24
+ # Load weights into model
25
  model.load_state_dict(state_dict)
26
  model.eval()
27
 
28
+ # Define prediction function
29
+ def predict(image: Image.Image):
30
+ inputs = extractor(images=image, return_tensors="pt")
31
  with torch.no_grad():
32
  outputs = model(**inputs)
33
  probs = torch.softmax(outputs.logits, dim=1)[0]
34
  top_idx = torch.argmax(probs).item()
35
+ predicted_class = class_names[top_idx]
36
+ return f"Prediction: {predicted_class}\nPlease consult a doctor for further diagnosis."
37
 
38
+ # Set up Gradio interface
39
  interface = gr.Interface(
40
  fn=predict,
41
+ inputs=gr.Image(type="pil", label="Upload Lung CT Scan"),
42
+ outputs=gr.Textbox(label="Prediction"),
43
+ title="Lung Cancer Classifier using ViT",
44
+ description="Upload a CT scan image to classify it as Normal, Cancer, or Malignant. This model is based on a Vision Transformer (ViT)."
45
  )
46
 
47
  if __name__ == "__main__":
48
  interface.launch()
49
+