ayushsinha commited on
Commit
26ef1c2
Β·
verified Β·
1 Parent(s): 21588a3

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +54 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.models as models
4
+ from huggingface_hub import hf_hub_download
5
+ import json
6
+ from PIL import Image
7
+ import torchvision.transforms as transforms
8
+
9
+ # Set device
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ # Load model weights and class labels
13
+ weights_path = hf_hub_download(repo_id="AventIQ-AI/resnet18-cataract-detection-system", filename="cataract_detection_resnet18_quantized.pth")
14
+ labels_path = hf_hub_download(repo_id="AventIQ-AI/resnet18-cataract-detection-system", filename="class_names.json")
15
+
16
+ with open(labels_path, "r") as f:
17
+ class_labels = json.load(f)
18
+
19
+ # Load model
20
+ model = models.resnet18(pretrained=False)
21
+ num_classes = len(class_labels)
22
+ model.fc = torch.nn.Linear(in_features=512, out_features=num_classes)
23
+ model.load_state_dict(torch.load(weights_path, map_location=device))
24
+ model.to(device)
25
+ model.eval()
26
+
27
+ # Define transform
28
+ transform = transforms.Compose([
29
+ transforms.Resize((224, 224)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
32
+ ])
33
+
34
+ # Prediction function
35
+ def predict(image):
36
+ image = transform(image).unsqueeze(0).to(device) # Preprocess image
37
+ with torch.no_grad():
38
+ output = model(image)
39
+ _, predicted = torch.max(output, 1)
40
+ predicted_class = class_labels[predicted.item()]
41
+ return {predicted_class: 1.0} # Confidence is assumed to be 1.0 for simplicity
42
+
43
+ # Gradio Interface
44
+ demo = gr.Interface(
45
+ fn=predict,
46
+ inputs=gr.Image(type="pil", label="Upload an Eye Image"),
47
+ outputs=gr.Label(label="Cataract Detection Result"),
48
+ title="πŸ‘οΈ Cataract Detection System πŸ₯",
49
+ description="πŸ”¬ Upload an eye image, and the AI model will determine if cataract is present! πŸ‘“",
50
+ theme="huggingface",
51
+ allow_flagging="never",
52
+ )
53
+
54
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ sentencepiece
5
+ torchvision
6
+ huggingface_hub
7
+ pillow
8
+ numpy