Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- env.env +1 -0
- model.py +80 -0
- telegram-bot.py +229 -0
env.env
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
API_ENDPOINT=https://226454676c90.ngrok-free.app/api/process-annotation
|
model.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.models as models
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
from PIL import Image
|
6 |
+
import time
|
7 |
+
|
8 |
+
class AE(nn.Module):
    """Dense autoencoder over 2048-d ResNet-50 feature vectors.

    Compresses a feature vector through a 128-d bottleneck and
    reconstructs it; the reconstruction error is used downstream as an
    anomaly score for the cropped ultrasound structure.
    """

    def __init__(self):
        super().__init__()
        # 2048 -> 512 -> 128 bottleneck
        self.encoder = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )
        # mirror of the encoder: 128 -> 512 -> 2048
        self.decoder = nn.Sequential(
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, 2048),
        )

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)
|
21 |
+
|
22 |
+
# Frozen ResNet-50 feature extractor: the classifier head is replaced with
# Identity so the network outputs the 2048-d pooled feature vector instead
# of class logits.  Weights are the default ImageNet pretrained set.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet.fc = nn.Identity()
resnet.eval()  # inference mode: freezes dropout / batch-norm statistics

# Single shared autoencoder instance; model_option() swaps its weights in
# place per view/category.  NOTE(review): this makes the module stateful
# and not safe for concurrent requests -- verify the server is single-worker.
autoencoder = AE()
|
27 |
+
|
28 |
+
# Checkpoint path per short category code; each file name embeds the
# validation loss the autoencoder was saved at.
_CHECKPOINTS = {
    "abdomen": "models/abdomen_autoencoder-0.0058.pth",
    "body": "models/body_autoencoder-0.0060.pth",
    "diencephalon": "models/diencephalon_autoencoder-0.0050.pth",
    "gsac": "models/gestation_sac_autoencoder-0.0044.pth",
    "head": "models/head_autoencoder-0.0077.pth",
    "lv": "models/lateral_ventricle_autoencoder-0.0045.pth",
    "mx": "models/maxilla_autoencoder-0.0054.pth",
    "mds": "models/mds_mandible_autoencoder-0.0039.pth",
    "mls": "models/mls_mandible_ventricle_autoencoder-0.0047.pth",
    "nb": "models/nasal_bone_autoencoder-0.0026.pth",
    "ntaps": "models/ntaps_autoencoder-0.0032.pth",
    "rbp": "models/rhombencephalon_autoencoder-0.0044.pth",
    "thorax": "models/thorax_autoencoder-0.0058.pth",
}


def model_option(view, category):
    """Load the checkpoint for *category* into the shared ``autoencoder``.

    Args:
        view: ultrasound view code ("crl" or "nt").  NOTE(review): the
            original only handled "crl" and, for "nt", silently kept
            whatever weights were loaded last (a commented-out branch hints
            NT-specific checkpoints were planned).  Until those exist, the
            per-category checkpoint is used for both views.
        category: short category code; must be a key of ``_CHECKPOINTS``.

    Raises:
        ValueError: if *category* has no checkpoint.  The original fell
            through silently and scored with stale weights.
    """
    try:
        path = _CHECKPOINTS[category]
    except KeyError:
        raise ValueError(f"unknown category: {category!r}") from None
    # map_location keeps this working on CPU-only hosts even if the
    # checkpoints were saved from a GPU; weights_only avoids unpickling
    # arbitrary objects from the checkpoint file.
    autoencoder.load_state_dict(
        torch.load(path, map_location="cpu", weights_only=True)
    )
    autoencoder.eval()
|
58 |
+
|
59 |
+
# ImageNet preprocessing pipeline matching the ResNet-50 pretraining:
# resize to 224x224, convert to a [0, 1] float tensor, then normalize with
# the standard ImageNet per-channel means and standard deviations.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
|
65 |
+
|
66 |
+
def predict(cropped, view, category):
    """Return the autoencoder reconstruction error (MSE) for a cropped image.

    Args:
        cropped: PIL.Image of the cropped anatomical structure.
        view: ultrasound view code, forwarded to ``model_option``.
        category: short category code selecting which checkpoint to load.

    Returns:
        float reconstruction error; larger values suggest the crop is
        anomalous for the chosen category.

    NOTE(review): mutates the shared global ``autoencoder`` via
    ``model_option``, so concurrent calls are not safe.
    """
    model_option(view, category)
    img_tensor = transform(cropped.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        # ResNet-50 with an Identity head yields a (1, 2048) feature tensor;
        # feed it straight to the autoencoder.  The original detoured through
        # .squeeze().numpy() and back to a tensor for no benefit.
        feat = resnet(img_tensor)
        recon = autoencoder(feat)

    return nn.functional.mse_loss(recon, feat).item()
|
telegram-bot.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
import logging
import os
import time

import aiohttp
from aiohttp import FormData
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import (
    Application, CommandHandler, MessageHandler,
    filters, CallbackQueryHandler, ContextTypes
)
|
14 |
+
load_dotenv(dotenv_path='env.env')
|
15 |
+
|
16 |
+
|
17 |
+
# --- configuration ---------------------------------------------------------
# SECURITY: the bot token used to be hard-coded here and is now exposed in
# version control.  Prefer the TELEGRAM_BOT_TOKEN environment variable (e.g.
# set it in env.env); the literal fallback is kept only for backward
# compatibility and the leaked token should be revoked via @BotFather.
TOKEN = os.getenv("TELEGRAM_BOT_TOKEN", "8320924107:AAH505mhHkOxeY3aLk0GObIpO_KCtY9hhLM")
# Annotation-processing API (ngrok tunnel), read from env.env via load_dotenv.
API_ENDPOINT = os.getenv("API_ENDPOINT")
logging.basicConfig(level=logging.INFO)
|
21 |
+
|
22 |
+
# Category mappings: human-readable labels (shown on the inline keyboard)
# mapped to the short codes sent to the backend API, keyed by ultrasound
# view.  NOTE(review): the codes "ab", "bd" and "dp" do not match the codes
# the model loader in model.py expects ("abdomen", "body", "diencephalon")
# -- verify the API translates them, otherwise those categories never load
# their checkpoints.
VIEW_CATEGORIES = {
    "crl": {
        "Maxilla": "mx",
        "Mandible-MDS": "mds",
        "Mandible-MLS": "mls",
        "Lateral ventricle": "lv",
        "Head": "head",
        "Gestational sac": "gsac",
        "Thorax": "thorax",
        "Abdomen": "ab",
        "Body(Biparietal diameter)": "bd",
        "Rhombencephalon": "rbp",
        "Diencephalon": "dp",
        "NTAPS": "ntaps",
        "Nasal bone": "nb"
    },
    # NT view: no gestational sac / body, adds nuchal translucency.
    "nt": {
        "Maxilla": "mx",
        "Mandible-MDS": "mds",
        "Mandible-MLS": "mls",
        "Lateral ventricle": "lv",
        "Head": "head",
        "Thorax": "thorax",
        "Abdomen": "ab",
        "Rhombencephalon": "rbp",
        "Diencephalon": "dp",
        "Nuchal translucency": "nt",
        "NTAPS": "ntaps",
        "Nasal bone": "nb"
    }
}
|
54 |
+
|
55 |
+
# /start -- entry-point greeting
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Greet the user and point them at the image-upload workflow."""
    greeting = "Send an ultrasound image (JPG/png only) to begin. See /instructions first"
    await update.message.reply_text(greeting)
|
58 |
+
|
59 |
+
async def instructions(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Send the step-by-step usage guide for the /instructions command."""
    # NOTE(review): the stray "π"/"πΈ"/"β’" runes throughout this string
    # look like mojibake of emoji/bullets from an encoding round-trip --
    # confirm and re-save the file as UTF-8 if so.  Text kept byte-identical.
    instruction_message = """
π **Instructions**

πΈ **Step 1:** Send a cropped ultrasound image of a structure you want to analyze. See all structures with /list command
πΈ **Step 2:** Select the ultrasound view (CRL or NT)
πΈ **Step 3:** Choose the anatomical category to analyze
πΈ **Step 4:** Wait for the AI analysis results

π **Important notes:**
β’ Only JPG/PNG images are supported
β’ Ensure the ultrasound image is clear and properly oriented
β’ Results are for reference only - always consult a medical professional
β’ Processing may take a few seconds
β’ Stop bot with /stop

π‘ **Tips:**
β’ Use high-quality, well-lit images for better accuracy
β’ Make sure the anatomical structure is clearly visible
β’ Different views (CRL/NT) have different category options

π **Need help?** Contact support if you encounter any issues @d3ikshr.
"""

    await update.message.reply_text(instruction_message, parse_mode='Markdown')
|
84 |
+
|
85 |
+
async def list_categories(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Send the per-view category listing for the /list command.

    The listing mirrors VIEW_CATEGORIES; keep the two in sync when adding
    or removing categories.
    """
    list_message = """
π **Available Categories by View**

π **CRL view categories:**
β’ Maxilla β’ Mandible-MDS β’ Mandible-MLS
β’ Lateral ventricle β’ Head β’ Gestational sac
β’ Thorax β’ Abdomen β’ Body(Biparietal diameter)
β’ Rhombencephalon β’ Diencephalon β’ NTAPS
β’ Nasal bone

π **NT view categories:**
β’ Maxilla β’ Mandible-MDS β’ Mandible-MLS
β’ Lateral ventricle β’ Head β’ Thorax
β’ Abdomen β’ Rhombencephalon β’ Diencephalon
β’ Nuchal translucency β’ NTAPS β’ Nasal bone

π‘ **Note:** Categories will be shown automatically based on your selected view during analysis.
"""

    await update.message.reply_text(list_message, parse_mode='Markdown')
|
106 |
+
|
107 |
+
async def stop(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Reset per-chat state for the /stop command.

    Clears ``user_data`` *before* sending the confirmation so the state is
    dropped even if the reply fails (the original cleared it afterwards,
    leaving stale state behind on a send error).
    """
    context.user_data.clear()
    await update.message.reply_text("π Bot stopped for this chat. Use /start to begin again.")
|
111 |
+
|
112 |
+
# Step 1: receive the uploaded photo and ask which view it is.
async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Download the highest-resolution photo, stash its path, prompt for view."""
    largest = update.message.photo[-1]
    tg_file = await largest.get_file()
    local_path = f"{update.message.from_user.id}_ultrasound.jpg"
    await tg_file.download_to_drive(local_path)
    context.user_data["image_path"] = local_path

    # One row with both view choices; callback data is parsed in handle_view.
    view_row = [
        InlineKeyboardButton("CRL", callback_data="view:crl"),
        InlineKeyboardButton("NT", callback_data="view:nt"),
    ]
    markup = InlineKeyboardMarkup([view_row])
    await update.message.reply_text("Select the ultrasound view:", reply_markup=markup)
|
129 |
+
|
130 |
+
# Step 2: the user picked a view; offer that view's categories.
async def handle_view(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Store the chosen view and show its category keyboard."""
    query = update.callback_query
    await query.answer()
    chosen_view = query.data.split(":")[1]
    context.user_data["selected_view"] = chosen_view

    # One button per category label; dicts preserve insertion order, so the
    # keyboard matches the order declared in VIEW_CATEGORIES.
    keyboard = [
        [InlineKeyboardButton(label, callback_data=f"category:{label}")]
        for label in VIEW_CATEGORIES[chosen_view]
    ]
    await query.edit_message_text(
        "Select anatomical category:",
        reply_markup=InlineKeyboardMarkup(keyboard),
    )
|
145 |
+
|
146 |
+
# Step 3: the user picked a category; upload the image to the API and report.
async def handle_category(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Upload the stored image plus view/category to the backend and reply.

    Fixes over the original:
      * ``time.sleep(3)`` blocked the entire asyncio event loop (no other
        chat could be served for 3 s) -- replaced with ``await
        asyncio.sleep(3)``.
      * removed a stray ``*`` after the diagnosis line that produced
        unbalanced Markdown (Telegram rejects messages it cannot parse).
    The temporary image file is always removed in ``finally``.
    """
    query = update.callback_query
    await query.answer()
    category_display = query.data.split(":")[1]
    context.user_data["selected_category"] = category_display

    image_path = context.user_data.get("image_path")
    view = context.user_data.get("selected_view")

    if not image_path or not view:
        await query.edit_message_text("Missing image or view.")
        return

    await query.edit_message_text("π Processing image...")
    # Cosmetic pause only; must not block the event loop.
    await asyncio.sleep(3)
    await query.edit_message_text("π₯ Building diagnosis, please wait...")

    try:
        # Read the image into memory so the handle is closed immediately.
        with open(image_path, "rb") as f:
            image_data = f.read()

        # Map the human-readable label to the short API code.
        category_value = VIEW_CATEGORIES[view][category_display]

        # Multipart form expected by the annotation API.
        form = FormData()
        form.add_field("view", view)
        form.add_field("category", category_value)
        form.add_field("source", "telegram")
        form.add_field("image", image_data, filename="image.jpg", content_type="image/jpeg")

        async with aiohttp.ClientSession() as session:
            async with session.post(API_ENDPOINT, data=form) as resp:
                if resp.status == 200:
                    result = await resp.json()

                    message = f"""
π **Analysis Results**

π **View:** {result.get('view', view).upper()}
π₯ **Category:** {category_display}

π **Confidence:** {result.get('confidence', 0):.2f}%
β οΈ **Reconstruction error:** {result.get('error', 0):.5f}

π **Status:** {result.get('comment', 'No comment')}

π©Ί **Diagnosis:** {result.get('diagnosis', 'No diagnosis')}
"""

                    await query.edit_message_text(message, parse_mode='Markdown')

                else:
                    error_text = await resp.text()
                    await query.edit_message_text(f"β Upload failed. Status: {resp.status}\nError: {error_text}")

    except Exception as e:
        logging.error(f"Error processing request: {e}")
        await query.edit_message_text(f"β Error: {str(e)}")

    finally:
        # Best-effort cleanup of the downloaded image, success or failure.
        try:
            if os.path.exists(image_path):
                os.remove(image_path)
                logging.info(f"Cleaned up image file: {image_path}")
        except Exception as e:
            logging.error(f"Error cleaning up file: {e}")
216 |
+
|
217 |
+
def main():
    """Wire up all handlers and run the bot with long polling."""
    application = Application.builder().token(TOKEN).build()

    # Slash commands; dict preserves insertion order so registration order
    # is identical to the original explicit add_handler calls.
    commands = {
        "start": start,
        "instructions": instructions,
        "list": list_categories,
        "stop": stop,
    }
    for name, callback in commands.items():
        application.add_handler(CommandHandler(name, callback))

    application.add_handler(MessageHandler(filters.PHOTO, handle_image))
    application.add_handler(CallbackQueryHandler(handle_view, pattern="^view:"))
    application.add_handler(CallbackQueryHandler(handle_category, pattern="^category:"))

    application.run_polling()


if __name__ == "__main__":
    main()
|