eyupipler commited on
Commit
259d628
·
verified ·
1 Parent(s): 256f777

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -39
app.py CHANGED
@@ -1,45 +1,43 @@
 
1
  import torch
2
- import torch.nn as nn
3
- from huggingface_hub import hf_hub_download
 
 
4
 
5
- class SimpleCNN(nn.Module):
6
- def __init__(self, num_classes=6):
7
- super(SimpleCNN, self).__init__()
8
- self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
9
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
10
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
11
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
12
- self.relu = nn.ReLU()
13
- self.dropout = nn.Dropout(0.5)
14
- self._initialize_fc(num_classes)
15
 
16
- def _initialize_fc(self, num_classes):
17
- dummy_input = torch.zeros(1, 3, 448, 448)
18
- x = self.pool(self.relu(self.conv1(dummy_input)))
19
- x = self.pool(self.relu(self.conv2(x)))
20
- x = self.pool(self.relu(self.conv3(x)))
21
- flattened_size = x.view(x.size(0), -1).shape[1]
22
- self.fc1 = nn.Linear(flattened_size, 512)
23
- self.fc2 = nn.Linear(512, num_classes)
24
 
25
- def forward(self, x):
26
- x = self.pool(self.relu(self.conv1(x)))
27
- x = self.pool(self.relu(self.conv2(x)))
28
- x = self.pool(self.relu(self.conv3(x)))
29
- x = x.view(x.size(0), -1)
30
- x = self.dropout(self.relu(self.fc1(x)))
31
- x = self.fc2(x)
32
- return x
33
 
 
 
 
 
 
 
 
 
 
34
 
35
- def load_model(device: str = 'cpu'):
36
- weights_path = hf_hub_download(
37
- repo_id="Neurazum/Vbai-DPA-2.2",
38
- filename="Vbai-DPA 2.2c.pt",
39
- repo_type="model"
40
- )
41
- model = SimpleCNN(num_classes=6).to(device)
42
- state = torch.load(weights_path, map_location=device)
43
- model.load_state_dict(state)
44
- model.eval()
45
- return model
 
1
+ import gradio as gr
2
  import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from model import load_model
6
+ import numpy as np
7
 
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model = load_model(device)
 
 
 
 
 
 
 
 
10
 
11
+ class_names = [
12
+ 'Alzheimer Disease',
13
+ 'Mild Alzheimer Risk',
14
+ 'Moderate Alzheimer Risk',
15
+ 'Very Mild Alzheimer Risk',
16
+ 'No Risk',
17
+ 'Parkinson Disease'
18
+ ]
19
 
20
+ transform = transforms.Compose([
21
+ transforms.Resize((448, 448)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
24
+ std=[0.229, 0.224, 0.225])
25
+ ])
 
 
26
 
27
+ def predict(image):
28
+ image = image.convert("RGB")
29
+ tensor = transform(image).unsqueeze(0).to(device)
30
+ with torch.no_grad():
31
+ outputs = model(tensor)
32
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0]
33
+ predicted = torch.argmax(probs).item()
34
+ confidence = probs[predicted].item() * 100
35
+ return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
36
 
37
+ gr.Interface(
38
+ fn=predict,
39
+ inputs=gr.Image(type="pil"),
40
+ outputs=gr.Label(num_top_classes=3),
41
+ title="Vbai-DPA 2.2 (C Version)",
42
+ description="Upload an MRI and fMRI image to classify the risk level using the 'C' version of the Vbai-DPA 2.2 model."
43
+ ).launch()