Zuhyer999 commited on
Commit
3c1db92
·
verified ·
1 Parent(s): 72c8baa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -5,10 +5,10 @@ from torchvision import models
5
  from PIL import Image
6
  import gradio as gr
7
 
8
- # Class labels (must match model output)
9
  class_names = ["Normal", "Cancer", "Malignant"]
10
 
11
- # Define transform for CT scan input
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
14
  transforms.ToTensor(),
@@ -16,7 +16,7 @@ transform = transforms.Compose([
16
  std=[0.229, 0.224, 0.225])
17
  ])
18
 
19
- # Load the VGG16 model structure and adjust classifier
20
  def get_model():
21
  model = models.vgg16(pretrained=False)
22
  model.classifier[6] = nn.Linear(4096, 3) # 3 output classes
@@ -27,22 +27,23 @@ model = get_model()
27
  model.load_state_dict(torch.load("distilled_vgg16.pth", map_location=torch.device("cpu")))
28
  model.eval()
29
 
30
- # Prediction function
31
  def predict(img: Image.Image):
32
- image = transform(img).unsqueeze(0) # Add batch dimension
33
  with torch.no_grad():
34
  output = model(image)
35
- probs = torch.softmax(output, dim=1)[0] # Convert logits to probabilities
36
- return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
37
 
38
- # Gradio UI
39
  interface = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil"),
42
- outputs=gr.Label(num_top_classes=3),
43
  title="Lung Cancer Classifier (PyTorch)",
44
- description="Upload a CT scan image to classify it as Normal, Cancer, or Malignant."
45
  )
46
 
47
  if __name__ == "__main__":
48
  interface.launch()
 
 
5
  from PIL import Image
6
  import gradio as gr
7
 
8
+ # Define class names
9
  class_names = ["Normal", "Cancer", "Malignant"]
10
 
11
+ # Image transform
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
14
  transforms.ToTensor(),
 
16
  std=[0.229, 0.224, 0.225])
17
  ])
18
 
19
+ # Load VGG16 model and modify final layer
20
  def get_model():
21
  model = models.vgg16(pretrained=False)
22
  model.classifier[6] = nn.Linear(4096, 3) # 3 output classes
 
27
  model.load_state_dict(torch.load("distilled_vgg16.pth", map_location=torch.device("cpu")))
28
  model.eval()
29
 
30
+ # Prediction function — return only top class
31
  def predict(img: Image.Image):
32
+ image = transform(img).unsqueeze(0)
33
  with torch.no_grad():
34
  output = model(image)
35
+ predicted_class = torch.argmax(output, dim=1).item()
36
+ return f"Likely: {class_names[predicted_class]}"
37
 
38
+ # Gradio Interface
39
  interface = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil"),
42
+ outputs=gr.Textbox(label="Prediction"),
43
  title="Lung Cancer Classifier (PyTorch)",
44
+ description="Upload a lung CT scan. The model predicts the most likely condition: Normal, Cancer, or Malignant."
45
  )
46
 
47
  if __name__ == "__main__":
48
  interface.launch()
49
+