Zuhyer999 commited on
Commit
37f5c64
·
verified ·
1 Parent(s): 576aa4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -20
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import torch
2
  import torch.nn as nn
3
- import torchvision.transforms as transforms
4
- from torchvision import models
5
  from PIL import Image
6
  import gradio as gr
7
 
8
- # Class names
9
  class_names = ["Normal", "Cancer", "Malignant"]
10
 
11
- # Image preprocessing
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
14
  transforms.ToTensor(),
@@ -16,35 +15,29 @@ transform = transforms.Compose([
16
  std=[0.229, 0.224, 0.225])
17
  ])
18
 
19
- # Load model architecture
20
  def get_model():
21
  model = models.vgg16(pretrained=False)
22
  model.classifier[6] = nn.Linear(4096, 3) # 3 output classes
23
  return model
24
 
25
- # Load trained weights
26
  model = get_model()
27
  model.load_state_dict(torch.load("distilled_vgg16.pth", map_location=torch.device("cpu")))
28
  model.eval()
29
 
30
- # Prediction function
31
  def predict(img: Image.Image):
32
- image = transform(img).unsqueeze(0) # Add batch dimension
33
  with torch.no_grad():
34
- output = model(image)
35
- probabilities = torch.softmax(output, dim=1)[0]
36
- predicted_class = torch.argmax(probabilities).item()
37
- confidence = probabilities[predicted_class].item()
38
- return f"Likely: {class_names[predicted_class]} (Confidence: {confidence * 100:.2f}%)"
39
 
40
- # Gradio Interface
41
  interface = gr.Interface(
42
  fn=predict,
43
  inputs=gr.Image(type="pil"),
44
  outputs=gr.Textbox(label="Prediction"),
45
- title="Lung Cancer Detector (PyTorch VGG16)",
46
- description="Upload a CT scan image to predict: Normal, Cancer, or Malignant."
47
- )
48
-
49
- if __name__ == "__main__":
50
- interface.launch()
 
1
  import torch
2
  import torch.nn as nn
3
+ from torchvision import models, transforms
 
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+ # Define class labels
8
  class_names = ["Normal", "Cancer", "Malignant"]
9
 
10
+ # Define preprocessing for CT scan images
11
  transform = transforms.Compose([
12
  transforms.Resize((224, 224)),
13
  transforms.ToTensor(),
 
15
  std=[0.229, 0.224, 0.225])
16
  ])
17
 
18
+ # Define and load model
19
  def get_model():
20
  model = models.vgg16(pretrained=False)
21
  model.classifier[6] = nn.Linear(4096, 3) # 3 output classes
22
  return model
23
 
 
24
  model = get_model()
25
  model.load_state_dict(torch.load("distilled_vgg16.pth", map_location=torch.device("cpu")))
26
  model.eval()
27
 
28
+ # Define prediction function
29
  def predict(img: Image.Image):
30
+ image = transform(img).unsqueeze(0)
31
  with torch.no_grad():
32
+ outputs = model(image)
33
+ probabilities = torch.softmax(outputs, dim=1)[0]
34
+ top_class = torch.argmax(probabilities).item()
35
+ confidence = probabilities[top_class].item()
36
+ return f"Likely: {class_names[top_class]} (Confidence: {confidence*100:.2f}%)"
37
 
38
+ # Gradio UI
39
  interface = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil"),
42
  outputs=gr.Textbox(label="Prediction"),
43
+ title="Lung Cancer Classifier (PyTorch VGG1