devadethanr commited on
Commit
4ecb9f3
·
verified ·
1 Parent(s): 2a631c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -12,6 +12,8 @@ processor = AutoImageProcessor.from_pretrained(model_name)
12
  # Get the label names from the model's configuration
13
  labels = model.config.id2label
14
 
 
 
15
  # Define the prediction function (with preprocessing)
16
  def predict_image(image):
17
  """
@@ -25,32 +27,40 @@ def predict_image(image):
25
  """
26
  # Preprocessing steps:
27
  image = np.array(image)
28
- image = np.repeat(image[:, :, np.newaxis], 3, axis=2) # Convert grayscale to RGB
 
 
 
 
 
29
 
30
  # Model inference:
31
  inputs = processor(images=image, return_tensors="pt").to(model.device)
32
  with torch.no_grad():
33
  logits = model(**inputs).logits
34
 
 
 
 
35
  predicted_label_id = logits.argmax(-1).item()
36
- predicted_label = labels[str(predicted_label_id)]
37
 
38
  # Calculate probabilities using softmax
39
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
40
- confidences = {labels[str(i)]: float(probabilities[0][i]) for i in range(len(labels))}
41
 
42
  return predicted_label, confidences
43
 
44
- # Create the Gradio interface (updated)
45
  iface = gr.Interface(
46
  fn=predict_image,
47
- inputs=gr.Image(type="pil", label="Upload MRI Image"), # Use gr.Image directly
48
  outputs=[
49
  gr.Label(label="Prediction"),
50
  gr.JSON(label="Confidence Scores")
51
  ],
52
  title="Alzheimer's Disease MRI Image Classifier",
53
- description="Upload a MRI image to predict the stage of Alzheimer's disease."
54
  )
55
 
56
  iface.launch()
 
12
  # Get the label names from the model's configuration
13
  labels = model.config.id2label
14
 
15
+ print("Labels:", labels) # Debugging statement to check the labels
16
+
17
  # Define the prediction function (with preprocessing)
18
  def predict_image(image):
19
  """
 
27
  """
28
  # Preprocessing steps:
29
  image = np.array(image)
30
+ if image.ndim == 2: # Convert grayscale to RGB if needed
31
+ image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
32
+
33
+ # Resize image if necessary (optional step)
34
+ if image.shape[0] != 224 or image.shape[1] != 224:
35
+ image = np.array(Image.fromarray(image).resize((224, 224)))
36
 
37
  # Model inference:
38
  inputs = processor(images=image, return_tensors="pt").to(model.device)
39
  with torch.no_grad():
40
  logits = model(**inputs).logits
41
 
42
+ print(f"logits shape: {logits.shape}") # Debugging statement to check shape
43
+ print(f"logits: {logits}") # Debugging statement to check content
44
+
45
  predicted_label_id = logits.argmax(-1).item()
46
+ predicted_label = labels[predicted_label_id]
47
 
48
  # Calculate probabilities using softmax
49
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
50
+ confidences = {labels[i]: float(probabilities[0][i]) for i in range(len(labels))}
51
 
52
  return predicted_label, confidences
53
 
54
+ # Create the Gradio interface
55
  iface = gr.Interface(
56
  fn=predict_image,
57
+ inputs=gr.Image(type="pil", label="Upload MRI Image"),
58
  outputs=[
59
  gr.Label(label="Prediction"),
60
  gr.JSON(label="Confidence Scores")
61
  ],
62
  title="Alzheimer's Disease MRI Image Classifier",
63
+ description="Upload an MRI image to predict the stage of Alzheimer's disease."
64
  )
65
 
66
  iface.launch()