Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
"""демо.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1NK3gtM_1xpqJt79c_lDgu45FY4aMd3kr
"""
from huggingface_hub import hf_hub_download | |
import json

# Download the model configuration and the trained weights from the
# Hugging Face Hub; hf_hub_download returns the local path of each file.
REPO_ID = "alexakup05/eye_disease_classifier"
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
model_path = hf_hub_download(repo_id=REPO_ID, filename="model1.pth")
print(f"Модель и конфигурация загружены: {config_path}, {model_path}")

# Load the configuration. JSON text is UTF-8, so decode it explicitly rather
# than relying on the platform's default locale encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
print(config)  # Inspect the configuration contents
import torch | |
import torch.nn as nn | |
from torchvision import models | |
class EyeDiseaseEfficientNet(nn.Module):
    """EfficientNet-B4 image encoder fused with a small (age, sex) MLP.

    The image goes through an EfficientNet-B4 backbone whose classifier head
    is replaced by ``nn.Identity``, so the backbone emits 1792-dim pooled
    features. The 2-element age/sex vector goes through a 64-unit MLP. Both
    feature vectors are concatenated and mapped to ``num_classes`` logits.

    The attribute names ``efficientnet``, ``fc_age_sex`` and ``fc_combined``
    are the ``state_dict`` key prefixes of the saved checkpoint. Renaming them
    would break ``load_state_dict``.

    Args:
        config: Model configuration dict. It is currently unused and is
            accepted only for interface compatibility.
        num_classes: Number of output logits (default 8).
    """

    # Size of the pooled EfficientNet-B4 feature vector once the classifier
    # head has been removed.
    IMAGE_FEATURE_DIM = 1792

    def __init__(self, config, num_classes=8):
        super().__init__()
        # Random init (weights=None); trained weights are loaded afterwards via
        # load_state_dict. `pretrained=` is deprecated since torchvision 0.13.
        self.efficientnet = models.efficientnet_b4(weights=None)
        self.efficientnet.classifier = nn.Identity()
        self.fc_age_sex = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.fc_combined = nn.Sequential(
            nn.Linear(self.IMAGE_FEATURE_DIM + 64, 512),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(512, num_classes),
        )

    def forward(self, x_img, x_age_sex):
        """Return class logits for a batch.

        Args:
            x_img: Image batch, presumably (N, 3, H, W) float; confirm against
                the training pipeline.
            x_age_sex: (N, 2) float tensor holding [age, sex] per sample.

        Returns:
            (N, num_classes) tensor of unnormalized logits.
        """
        img_features = self.efficientnet(x_img)
        meta_features = self.fc_age_sex(x_age_sex)
        combined = torch.cat((img_features, meta_features), dim=1)
        return self.fc_combined(combined)
# Build the network and load the trained weights on the CPU.
device = torch.device("cpu")
model = EyeDiseaseEfficientNet(config)
# The checkpoint comes from a remote repository. weights_only=True restricts
# unpickling to tensors and plain containers (a state_dict), so the file
# cannot execute arbitrary code on load. Requires torch >= 1.13.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode: disables dropout

# Smoke test: one forward pass with a random 224x224 image and a sample
# [age=45, sex=1] vector, to confirm the weights and shapes line up.
input_image = torch.randn(1, 3, 224, 224).to(device)
input_age_sex = torch.tensor([[45, 1]], dtype=torch.float32).to(device)
with torch.no_grad():
    output = model(input_image, input_age_sex)
print(output)
import torch.nn.functional as F | |
# Turn the smoke-test logits produced by the forward pass into class
# probabilities and report the most likely class index. Use the actual model
# output rather than a pasted copy, so the printed class reflects this run.
logits = output
probabilities = F.softmax(logits, dim=1)
predicted_class = torch.argmax(probabilities, dim=1)
print(f"Предсказанный класс: {predicted_class.item()}")
import gradio as gr | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
def detect_eye(img):
    """Crop ``img`` to the first eye box reported by OpenCV's Haar cascade.

    The grayscale conversion below treats ``img`` as BGR.
    NOTE(review): Gradio numpy images are typically RGB, which only swaps the
    R/B grayscale weights here; confirm this is intended.

    Args:
        img: Colour image array (H x W x channels).

    Returns:
        The sub-array covered by the first detection, or ``img`` itself when
        the cascade finds nothing.
    """
    cascade_file = cv2.data.haarcascades + 'haarcascade_eye.xml'
    detector = cv2.CascadeClassifier(cascade_file)
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    boxes = detector.detectMultiScale(
        grayscale,
        scaleFactor=1.1,
        minNeighbors=5,
        minSize=(30, 30),
    )
    if len(boxes) == 0:
        return img  # nothing detected: keep the full frame
    left, top, width, height = boxes[0]
    return img[top:top + height, left:left + width]
def preprocess_image(img):
    """Denoise ``img`` and boost local contrast of its lightness channel.

    Applies a 3x3 median blur, then CLAHE (clip limit 2.0, 8x8 tiles) to the
    L channel in LAB colour space. The A and B channels are left untouched.

    Args:
        img: Colour image. The conversions below interpret it as BGR
            (presumably uint8; confirm).

    Returns:
        The enhanced image, converted back from LAB with LAB2BGR.
    """
    denoised = cv2.medianBlur(img, 3)
    lab_image = cv2.cvtColor(denoised, cv2.COLOR_BGR2LAB)
    lightness, chan_a, chan_b = cv2.split(lab_image)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2BGR)
def resize_with_padding(img, target_size=(224, 224)):
    """Letterbox ``img`` into ``target_size`` while preserving aspect ratio.

    The image is scaled to fit inside ``target_size`` and then centred on a
    black canvas. When the padding is odd, the extra pixel goes to the
    bottom/right edge.

    Args:
        img: Image array whose first two dimensions are height and width
            (both > 0).
        target_size: Output (height, width).

    Returns:
        Padded image whose first two dimensions equal ``target_size``.
    """
    target_h, target_w = target_size
    h, w = img.shape[:2]
    scale = min(target_h / h, target_w / w)
    # Clamp to >= 1 px: for very elongated inputs int() can truncate one side
    # to 0, which cv2.resize rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    return cv2.copyMakeBorder(
        resized_img,
        pad_top, target_h - new_h - pad_top,
        pad_left, target_w - new_w - pad_left,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
# Output class names, in the same order as the model's logits.
DISEASE_LABELS = [
    "Normal",
    "Diabetic Retinopathy",
    "Glaucoma",
    "Cataract",
    "Age-related Macular Degeneration",
    "Hypertension",
    "Pathological Myopia",
    "Other Diseases/Abnormalities",
]


def predict(age, sex, img):
    """Classify an eye image together with the patient's age and sex.

    Pipeline: crop to the first detected eye, median blur + CLAHE, letterbox
    to 224x224, then run the model and softmax its logits.

    Args:
        age: Patient age (from the slider).
        sex: "Male" is encoded as 0. Any other value, including "Female",
            is encoded as 1.
        img: Image array from ``gr.Image(type="numpy")``, or None when the
            user submitted without uploading an image.

    Returns:
        A tuple ``(result, image)``. ``result`` maps each label to a
        probability string formatted as ``"NN.NN%"``. ``image`` is the PIL
        image that was fed to the model.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Show a readable message in the UI instead of an OpenCV traceback.
        raise gr.Error("Please upload an eye image.")

    img = detect_eye(img)
    img = preprocess_image(img)
    img = resize_with_padding(img)
    # NOTE(review): the pipeline treats its input as BGR, but Gradio numpy
    # images are typically RGB, so this conversion may actually produce BGR.
    # Pixel values are also left in [0, 255] without normalization. Confirm
    # both against the training preprocessing before changing them.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb)

    # HWC uint8 -> NCHW float (batch of one).
    img_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float()
    age_sex_tensor = torch.tensor([[age, 0 if sex == "Male" else 1]]).float()

    with torch.no_grad():
        outputs = model(img_tensor, age_sex_tensor)
    probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]

    result = {
        label: f"{prob * 100:.2f}%"
        for label, prob in zip(DISEASE_LABELS, probabilities)
    }
    return result, img
# Example rows: [age, sex, image path]. "myopia.png" is a relative path, so
# the file must be present for Gradio to load the example.
examples = [
    [30, "Male", "myopia.png"],
]

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Slider(minimum=0, maximum=100, step=1, label="Age"),
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Image(type="numpy", label="Upload Eye Image/ your Selfies / photo"),
    ],
    outputs=[gr.JSON(label="Predictions"), gr.Image(label="Processed Image")],
    examples=examples,
)

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    iface.launch()