devadethanr commited on
Commit
d8df7f2
·
verified ·
1 Parent(s): 9d28cf0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -18
app.py CHANGED
@@ -1,17 +1,13 @@
1
- # import gradio as gr
2
-
3
- # gr.load("models/devadethanr/alz_model").launch()
4
-
5
-
6
  import gradio as gr
7
- from transformers import AutoModelForImageClassification
8
  import torch
9
  import numpy as np
 
10
 
11
-
12
- # Load the model and image processor from the Hub
13
- model_name = "devadethanr/alz_model"
14
  model = AutoModelForImageClassification.from_pretrained(model_name)
 
15
 
16
  # Get the label names from the model's configuration
17
  labels = model.config.id2label
@@ -27,28 +23,31 @@ def predict_image(image):
27
  Returns:
28
  The predicted label with its corresponding probability.
29
  """
 
 
 
30
 
31
- image = model.preprocess_image(image, return_tensors="pt").to(model.device)
 
32
  with torch.no_grad():
33
- logits = model(**image).logits
34
 
35
  predicted_label_id = logits.argmax(-1).item()
36
- predicted_label = labels[predicted_label_id]
37
 
38
  # Calculate probabilities using softmax
39
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
40
- confidences = {label: float(probabilities[0][i]) for i, label in enumerate(labels)}
41
 
42
  return predicted_label, confidences
43
 
44
-
45
- # Create the Gradio interface (same as before)
46
  iface = gr.Interface(
47
  fn=predict_image,
48
- inputs=gr.inputs.Image(type="pil", label="Upload MRI Image"),
49
  outputs=[
50
- gr.outputs.Label(label="Prediction"),
51
- gr.outputs.JSON(label="Confidence Scores")
52
  ],
53
  title="Alzheimer's Disease MRI Image Classifier",
54
  description="Upload an MRI image to predict the stage of Alzheimer's disease."
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
3
  import torch
4
  import numpy as np
5
+ from PIL import Image
6
 
7
+ # Load the model from the Hub
8
+ model_name = "devadethanr/alz_model"
 
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
+ processor = AutoImageProcessor.from_pretrained(model_name)
11
 
12
  # Get the label names from the model's configuration
13
  labels = model.config.id2label
 
23
  Returns:
24
  The predicted label with its corresponding probability.
25
  """
26
+ # Preprocessing steps:
27
+ image = np.array(image)
28
+ image = np.repeat(image[:, :, np.newaxis], 3, axis=2) # Convert grayscale to RGB
29
 
30
+ # Model inference:
31
+ inputs = processor(images=image, return_tensors="pt").to(model.device)
32
  with torch.no_grad():
33
+ logits = model(**inputs).logits
34
 
35
  predicted_label_id = logits.argmax(-1).item()
36
+ predicted_label = labels[str(predicted_label_id)]
37
 
38
  # Calculate probabilities using softmax
39
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
40
+ confidences = {labels[str(i)]: float(probabilities[0][i]) for i in range(len(labels))}
41
 
42
  return predicted_label, confidences
43
 
44
+ # Create the Gradio interface (updated)
 
45
  iface = gr.Interface(
46
  fn=predict_image,
47
+ inputs=gr.Image(type="pil", label="Upload MRI Image"), # Use gr.Image directly
48
  outputs=[
49
+ gr.Label(label="Prediction"),
50
+ gr.JSON(label="Confidence Scores")
51
  ],
52
  title="Alzheimer's Disease MRI Image Classifier",
53
  description="Upload an MRI image to predict the stage of Alzheimer's disease."