CelagenexResearch committed
Commit aef38cb · verified · 1 Parent(s): eeadc66

Create app.py

Files changed (1)
1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
+ import gradio as gr
+ from PIL import Image
+ import torch
+ import numpy as np
+ import cv2
+ from transformers import CLIPProcessor, CLIPModel
+ # Hypothetical MedGemma imports (ensure you have model access and an HF token)
+ from medgem import MedGemmaProcessor, MedGemmaForImageClassification
+
+ # Device setup
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load CLIP model for breed, age, and basic health aspects
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+ # Load MedGemma for advanced medical insights
+ # (processors are plain Python objects, so only the model moves to the device)
+ medgemma_processor = MedGemmaProcessor.from_pretrained("google/medgemma-v1")
+ medgemma_model = MedGemmaForImageClassification.from_pretrained("google/medgemma-v1").to(device)
+
+ # Stanford Dogs breeds list & lifespan dict
+ STANFORD_BREEDS = [
+     "afghan hound", "african hunting dog", "airedale", "american staffordshire terrier",
+     # ... (full list from earlier)
+     "wire-haired fox terrier", "yorkshire terrier"
+ ]
+ BREED_LIFESPAN = {
+     "afghan hound": 11.1, "african hunting dog": 10.5, "airedale": 11.5,
+     # ... (full dict from earlier)
+     "yorkshire terrier": 13.3
+ }
+
+ # Healthspan questionnaire definitions
+ QUESTIONNAIRE = [
+     {"domain": "Mobility", "questions": [
+         "Does your dog have difficulty rising from lying down?",
+         "Does your dog hesitate before jumping up?"
+     ]},
+     {"domain": "Energy", "questions": [
+         "Does your dog tire quickly on walks?",
+         "Has your dog’s activity level decreased recently?"
+     ]},
+     {"domain": "Physical Health", "questions": [
+         "Does your dog scratch or lick skin frequently?",
+         "Any noticeable changes in appetite or weight?"
+     ]},
+     {"domain": "Cognitive", "questions": [
+         "Does your dog get lost in familiar rooms?",
+         "Does your dog stare blankly at walls/windows?"
+     ]},
+     {"domain": "Social", "questions": [
+         "Has your dog’s interest in play declined?",
+         "Does your dog avoid interaction with family?"
+     ]}
+ ]
+
+ # Unified scoring map for questionnaire (0–5 scale)
+ SCALE = ["0", "1", "2", "3", "4", "5"]
+
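+ # Zero-shot age estimate: score the photo against "<age>-year-old <breed>" prompts
+ # (1 year up to ~2x the breed's average lifespan) and return the argmax age.
+ # A rough visual heuristic, not a validated age predictor.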
+ def predict_biological_age(image: Image.Image, breed: str) -> int:
+     avg = BREED_LIFESPAN.get(breed.lower(), 12)
+     prompts = [f"a photo of a {age}-year-old {breed}" for age in range(1, int(avg * 2) + 1)]
+     inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         logits = clip_model(**inputs).logits_per_image.softmax(dim=1)[0].cpu().numpy()
+     return int(np.argmax(logits) + 1)
+
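+ # NOTE: MedGemmaForImageClassification and its id2label mapping are hypothetical
+ # placeholders; adapt to however your MedGemma checkpoint exposes predictions.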
+ def analyze_medical_image(image: Image.Image):
+     inputs = medgemma_processor(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = medgemma_model(**inputs)
+     probs = outputs.logits.softmax(dim=1)[0].cpu().numpy()
+     label = medgemma_model.config.id2label[int(np.argmax(probs))]
+     conf = float(np.max(probs))
+     return label, conf
+
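+ # CLIP zero-shot classification: embed the image and candidate texts, L2-normalize,
+ # then softmax the scaled cosine similarities (100.0 is CLIP's usual logit scale).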
+ def classify_breed_and_health(image: Image.Image, user_breed=None):
+     # Image features (L2-normalized for cosine similarity)
+     inputs = clip_processor(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         img_feats = clip_model.get_image_features(**inputs)
+     img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
+
+     # Breed classification
+     texts = [f"a photo of a {b}" for b in STANFORD_BREEDS]
+     t_in = clip_processor(text=texts, return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         text_feats = clip_model.get_text_features(**t_in)
+     text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
+     sims = (100.0 * img_feats @ text_feats.T).softmax(dim=-1)[0].cpu().numpy()
+     idx = sims.argmax()
+     breed = user_breed or STANFORD_BREEDS[idx]
+     breed_conf = float(sims[idx])
+
+     # Basic health aspects via CLIP
+     aspects = {
+         "Coat": ("shiny healthy coat", "dull patchy fur"),
+         "Eyes": ("bright clear eyes", "cloudy milky eyes"),
+         "Body": ("ideal muscle tone", "visible ribs or hip bones"),
+         "Teeth": ("clean white teeth", "yellow stained teeth")
+     }
+     health = {}
+     for name, (pos, neg) in aspects.items():
+         txt = clip_processor(text=[pos, neg], return_tensors="pt", padding=True).to(device)
+         with torch.no_grad():
+             tf = clip_model.get_text_features(**txt)
+         tf = tf / tf.norm(dim=-1, keepdim=True)
+         sim = (100.0 * img_feats @ tf.T).softmax(dim=-1)[0].cpu().numpy()
+         choice = pos if sim[0] > sim[1] else neg
+         health[name] = {"assessment": choice, "confidence": float(max(sim))}
+     return breed, breed_conf, health
+
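+ # Averages per-frame model confidence over 10 evenly spaced frames: a coarse
+ # proxy for gait/posture quality, not a real gait analysis.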
+ def analyze_video_health(video_path: str):
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or 24
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     duration = total_frames / fps
+     # sample 10 frames evenly
+     indices = np.linspace(0, max(total_frames - 1, 0), num=10, dtype=int)
+     gait_scores = []
+     for i in indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
+         ret, frame = cap.read()
+         if not ret:
+             continue  # skip unreadable frames; each frame is seeked independently
+         img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+         _, conf = analyze_medical_image(img)
+         gait_scores.append(conf)
+     cap.release()
+     avg_conf = float(np.mean(gait_scores)) if gait_scores else 0.0
+     return {"duration_sec": round(duration, 1), "avg_gait_confidence": avg_conf}
+
+ def compute_questionnaire_score(answers: list):
+     # answers arrive in QUESTIONNAIRE order: domain by domain, question by question
+     scores = {}
+     idx = 0
+     for section in QUESTIONNAIRE:
+         vals = list(map(int, answers[idx: idx + len(section["questions"])]))
+         idx += len(section["questions"])
+         scores[section["domain"]] = round(sum(vals) / len(vals), 2)
+     return scores
+
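+ # Gradio passes each Radio value as its own positional argument, so the widgets
+ # list built below must stay in QUESTIONNAIRE order for compute_questionnaire_score.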
+ # Build Gradio interface
+ with gr.Blocks(title="🐶 Dog Health & Age Analyzer") as demo:
+     gr.Markdown("## Upload an Image or Video, or Record a Short Clip")
+
+     with gr.Tab("Image Analysis"):
+         img_in = gr.Image(type="pil", label="Upload Dog Image")
+         breed_in = gr.Textbox(label="(Optional) Override Breed")
+         age_in = gr.Number(label="Chronological Age (years)", precision=1)
+         btn_img = gr.Button("Analyze Image")
+         out_md = gr.Markdown()
+
+         def run_image(img, breed_override, chrono_age):
+             breed, b_conf, health = classify_breed_and_health(img, breed_override)
+             med_label, med_conf = analyze_medical_image(img)
+             bio_age = predict_biological_age(img, breed)
+             pace = round(bio_age / chrono_age, 2) if chrono_age else None
+
+             # two trailing spaces before \n force Markdown line breaks
+             report = f"**Breed:** {breed} ({b_conf:.1%})  \n"
+             report += f"**MedGemma Finding:** {med_label} ({med_conf:.1%})  \n\n"
+             report += f"**Biological Age:** {bio_age} yrs  \n"
+             report += f"**Chronological Age:** {chrono_age or 'N/A'} yrs  \n"
+             if pace:
+                 report += f"**Pace of Aging:** {pace}×\n\n"
+             report += "### Health Aspects\n"
+             for k, v in health.items():
+                 report += f"- **{k}:** {v['assessment']} ({v['confidence']:.1%})\n"
+             return report
+
+         btn_img.click(run_image, inputs=[img_in, breed_in, age_in], outputs=out_md)
+
+     with gr.Tab("Video Analysis"):
+         video_in = gr.Video(label="Upload or Record Video (10–30s)")
+         btn_vid = gr.Button("Analyze Video")
+         vid_out = gr.JSON()
+
+         btn_vid.click(analyze_video_health, inputs=video_in, outputs=vid_out)
+
+     with gr.Tab("Healthspan Questionnaire"):
+         widgets = []
+         for section in QUESTIONNAIRE:
+             gr.Markdown(f"### {section['domain']}")
+             for q in section["questions"]:
+                 # default to "0" so unanswered questions don't crash int()
+                 w = gr.Radio(choices=SCALE, value="0", label=q)
+                 widgets.append(w)
+         btn_q = gr.Button("Compute Score")
+         q_out = gr.JSON()
+
+         btn_q.click(
+             fn=compute_questionnaire_score,
+             inputs=widgets,
+             outputs=q_out
+         )
+
+     with gr.Tab("About"):
+         gr.Markdown("""
+ - **MedGemma v1**: Veterinary medical image analysis
+ - **Video Module**: Gait & posture confidence score
+ - **Questionnaire**: Healthspan domains (Mobility, Energy, Physical, Cognitive, Social)
+ """)
+
+ demo.launch()