alexakup05 commited on
Commit
e5bb95a
·
1 Parent(s): f6aa039

Add application file

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """демо.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1NK3gtM_1xpqJt79c_lDgu45FY4aMd3kr
8
+ """
9
+
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ # Загрузка файла конфигурации и модели
13
+ config_path = hf_hub_download(repo_id="alexakup05/eye_disease_classifier", filename="config.json")
14
+ model_path = hf_hub_download(repo_id="alexakup05/eye_disease_classifier", filename="model1.pth")
15
+
16
+ print(f"Модель и конфигурация загружены: {config_path}, {model_path}")
17
+
18
+ import json
19
+
20
+ # Загружаем конфигурацию
21
+ with open(config_path, 'r') as f:
22
+ config = json.load(f)
23
+
24
+ print(config) # Проверим содержимое конфигурации
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ from torchvision import models
29
+
30
+ class EyeDiseaseEfficientNet(nn.Module):
31
+ def __init__(self, config):
32
+ super(EyeDiseaseEfficientNet, self).__init__()
33
+ self.efficientnet = models.efficientnet_b4(pretrained=False)
34
+ self.efficientnet.classifier = nn.Identity()
35
+ self.fc_age_sex = nn.Sequential(
36
+ nn.Linear(2, 64),
37
+ nn.ReLU(),
38
+ nn.Dropout(0.5)
39
+ )
40
+ self.fc_combined = nn.Sequential(
41
+ nn.Linear(1792 + 64, 512),
42
+ nn.ReLU(),
43
+ nn.Dropout(0.6),
44
+ nn.Linear(512, 8)
45
+ )
46
+
47
+ def forward(self, x_img, x_age_sex):
48
+ x_img = self.efficientnet(x_img)
49
+ x_age_sex = self.fc_age_sex(x_age_sex)
50
+ x = torch.cat((x_img, x_age_sex), dim=1)
51
+ x = self.fc_combined(x)
52
+ return x
53
+
54
+ model = EyeDiseaseEfficientNet(config)
55
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
56
+
57
+ device = torch.device("cpu")
58
+ model = model.to(device)
59
+ model.eval()
60
+
61
+ input_image = torch.randn(1, 3, 224, 224).to(device)
62
+ input_age_sex = torch.tensor([[45, 1]], dtype=torch.float32).to(device)
63
+
64
+ with torch.no_grad():
65
+ output = model(input_image, input_age_sex)
66
+ print(output)
67
+
68
+ import torch.nn.functional as F
69
+
70
+ logits = torch.tensor([[-2.6384, -1.8599, 0.0206, 2.0523, 0.2476, 1.9363, 1.5297, -1.0108]], device='cpu')
71
+ probabilities = F.softmax(logits, dim=1)
72
+ predicted_class = torch.argmax(probabilities, dim=1)
73
+ print(f"Предсказанный класс: {predicted_class.item()}")
74
+
75
+ import gradio as gr
76
+ import cv2
77
+ import numpy as np
78
+ from PIL import Image
79
+
80
+ def detect_eye(img):
81
+ eye_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_eye.xml')
82
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
83
+ eyes = eye_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
84
+ if len(eyes) > 0:
85
+ (x, y, w, h) = eyes[0]
86
+ img = img[y:y+h, x:x+w]
87
+ return img
88
+
89
+ def preprocess_image(img):
90
+ img = cv2.medianBlur(img, 3)
91
+ lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
92
+ l, a, b = cv2.split(lab)
93
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
94
+ l = clahe.apply(l)
95
+ lab = cv2.merge((l, a, b))
96
+ img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
97
+ return img
98
+
99
+ def resize_with_padding(img, target_size=(224, 224)):
100
+ h, w = img.shape[:2]
101
+ scale = min(target_size[0] / h, target_size[1] / w)
102
+ new_w, new_h = int(w * scale), int(h * scale)
103
+ resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
104
+ pad_w = (target_size[1] - new_w) // 2
105
+ pad_h = (target_size[0] - new_h) // 2
106
+ padded_img = cv2.copyMakeBorder(
107
+ resized_img, pad_h, target_size[0] - new_h - pad_h, pad_w, target_size[1] - new_w - pad_w,
108
+ cv2.BORDER_CONSTANT, value=[0, 0, 0]
109
+ )
110
+ return padded_img
111
+
112
+ def predict(age, sex, img):
113
+ img = detect_eye(img)
114
+ img = preprocess_image(img)
115
+ img = resize_with_padding(img)
116
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
117
+ img = Image.fromarray(img)
118
+ img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float()
119
+ age_sex_tensor = torch.tensor([[age, 0 if sex == "Male" else 1]]).float()
120
+ with torch.no_grad():
121
+ outputs = model(img_tensor, age_sex_tensor)
122
+ probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]
123
+ disease_labels = [
124
+ "Normal",
125
+ "Diabetic Retinopathy",
126
+ "Glaucoma",
127
+ "Cataract",
128
+ "Age-related Macular Degeneration",
129
+ "Hypertension",
130
+ "Pathological Myopia",
131
+ "Other Diseases/Abnormalities"
132
+ ]
133
+
134
+ result = {disease_labels[i]: f"{probabilities[i]*100:.2f}%" for i in range(len(disease_labels))}
135
+ return result, img
136
+
137
+ examples = [
138
+ [30, "Male", "myopia.png"]
139
+ ]
140
+
141
+ iface = gr.Interface(
142
+ fn=predict,
143
+ inputs=[
144
+ gr.Slider(minimum=0, maximum=100, step=1, label="Age"),
145
+ gr.Radio(["Male", "Female"], label="Gener"),
146
+ gr.Image(type="numpy", label="Upload Eye Image/ your Selfies / photo")
147
+ ],
148
+ outputs=[gr.JSON(label="Predictions"), gr.Image(label="Processed Image")],
149
+ examples=examples
150
+ )
151
+
152
+ iface.launch()