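"""Gradio demo: dog breed, biological-age, and health analysis with CLIP,
a (hypothetical) MedGemma medical-image model, a rough gait check from
video, and a healthspan questionnaire."""
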
import gradio as gr
from PIL import Image
import torch
import numpy as np
import cv2
from transformers import CLIPProcessor, CLIPModel
# Hypothetical MedGemma imports: `medgem` is a placeholder package name, not an
# official API; adapt to however you actually load MedGemma (requires model
# access and an HF token)
from medgem import MedGemmaProcessor, MedGemmaForImageClassification

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load CLIP model for breed, age, and basic health aspects
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# Load MedGemma for advanced medical insights
medgemma_processor = MedGemmaProcessor.from_pretrained("google/medgemma-v1")  # processors are not torch modules; no .to()
medgemma_model = MedGemmaForImageClassification.from_pretrained("google/medgemma-v1").to(device)

# Stanford Dogs breeds list & lifespan dict
STANFORD_BREEDS = [
    "afghan hound", "african hunting dog", "airedale", "american staffordshire terrier",
    # ... (full list from earlier)
    "wire-haired fox terrier", "yorkshire terrier"
]
BREED_LIFESPAN = {
    "afghan hound": 11.1, "african hunting dog": 10.5, "airedale": 11.5,
    # ... (full dict from earlier)
    "yorkshire terrier": 13.3
}

# Healthspan questionnaire definitions
QUESTIONNAIRE = [
    {"domain": "Mobility", "questions": [
        "Does your dog have difficulty rising from lying down?",
        "Does your dog hesitate before jumping up?"
    ]},
    {"domain": "Energy", "questions": [
        "Does your dog tire quickly on walks?",
        "Has your dog’s activity level decreased recently?"
    ]},
    {"domain": "Physical Health", "questions": [
        "Does your dog scratch or lick skin frequently?",
        "Any noticeable changes in appetite or weight?"
    ]},
    {"domain": "Cognitive", "questions": [
        "Does your dog get lost in familiar rooms?",
        "Does your dog stare blankly at walls/windows?"
    ]},
    {"domain": "Social", "questions": [
        "Has your dog’s interest in play declined?",
        "Does your dog avoid interaction with family?"
    ]}
]

# Unified scoring map for questionnaire (0–5 scale)
SCALE = ["0", "1", "2", "3", "4", "5"]

def predict_biological_age(image: Image.Image, breed: str) -> int:
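    """Zero-shot age estimate: score the image against CLIP prompts
    "a photo of a {age}-year-old {breed}" for ages 1..2x the breed's
    average lifespan and return the best-matching age in years."""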
    avg = BREED_LIFESPAN.get(breed.lower(), 12)
    prompts = [f"a photo of a {age}-year-old {breed}" for age in range(1, int(avg * 2) + 1)]
    inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits = clip_model(**inputs).logits_per_image.softmax(dim=1)[0].cpu().numpy()
    return int(np.argmax(logits) + 1)

def analyze_medical_image(image: Image.Image):
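    """Classify the image with the (hypothetical) MedGemma model and
    return the top predicted label with its softmax confidence."""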
    inputs = medgemma_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = medgemma_model(**inputs)
        probs = outputs.logits.softmax(dim=1)[0].cpu().numpy()
    label = medgemma_model.config.id2label[np.argmax(probs)]
    conf = float(np.max(probs))
    return label, conf

def classify_breed_and_health(image: Image.Image, user_breed=None):
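    """Identify the breed via CLIP zero-shot matching against the Stanford
    Dogs labels (unless the user overrides it), then rate four visible
    health aspects with positive/negative prompt pairs."""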
    # Image features, L2-normalised so dot products are cosine similarities
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_feats = clip_model.get_image_features(**inputs)
        img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)

    # Breed classification
    texts = [f"a photo of a {b}" for b in STANFORD_BREEDS]
    t_in = clip_processor(text=texts, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_feats = clip_model.get_text_features(**t_in)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
    sims = (img_feats @ text_feats.T).softmax(dim=-1)[0].cpu().numpy()
    idx = sims.argmax()
    breed = user_breed or STANFORD_BREEDS[idx]
    breed_conf = float(sims[idx])

    # Basic health aspects via CLIP
    aspects = {
        "Coat": ("shiny healthy coat", "dull patchy fur"),
        "Eyes": ("bright clear eyes", "cloudy milky eyes"),
        "Body": ("ideal muscle tone", "visible ribs or hip bones"),
        "Teeth": ("clean white teeth", "yellow stained teeth")
    }
    health = {}
    for name, (pos, neg) in aspects.items():
        txt = clip_processor(text=[pos, neg], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            tf = clip_model.get_text_features(**txt)
            tf = tf / tf.norm(dim=-1, keepdim=True)
        sim = (img_feats @ tf.T).softmax(dim=-1)[0].cpu().numpy()
        choice = pos if sim[0] > sim[1] else neg
        health[name] = {"assessment": choice, "confidence": float(max(sim))}
    return breed, breed_conf, health

def analyze_video_health(video_path: str):
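    """Sample 10 evenly spaced frames from the clip and average the
    MedGemma confidence across them as a rough gait/posture score."""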
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 24
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = total_frames / fps
    # sample 10 frames evenly
    indices = np.linspace(0, total_frames - 1, num=10, dtype=int)
    gait_scores = []
    for i in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if not ret:
            break
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        _, conf = analyze_medical_image(img)
        gait_scores.append(conf)
    cap.release()
    avg_conf = float(np.mean(gait_scores)) if gait_scores else 0.0
    return {"duration_sec": round(duration, 1), "avg_gait_confidence": avg_conf}

def compute_questionnaire_score(*answers):
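    """Average the 0-5 answers within each questionnaire domain.

    Gradio passes each Radio value as a separate positional argument,
    hence the *answers signature."""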
    # answers in order of QUESTIONNAIRE domains × questions
    scores = {}
    idx = 0
    for section in QUESTIONNAIRE:
        vals = list(map(int, answers[idx: idx + len(section["questions"])]))
        idx += len(section["questions"])
        scores[section["domain"]] = round(sum(vals) / len(vals), 2)
    return scores

# Build Gradio interface
with gr.Blocks(title="🐶 Dog Health & Age Analyzer") as demo:
    gr.Markdown("## Upload an Image or Video, or Record a Short Clip")

    with gr.Tab("Image Analysis"):
        img_in = gr.Image(type="pil", label="Upload Dog Image")
        breed_in = gr.Textbox(label="(Optional) Override Breed")
        age_in = gr.Number(label="Chronological Age (years)", precision=1)
        btn_img = gr.Button("Analyze Image")
        out_md = gr.Markdown()

        def run_image(img, breed_override, chrono_age):
            breed, b_conf, health = classify_breed_and_health(img, breed_override)
            med_label, med_conf = analyze_medical_image(img)
            bio_age = predict_biological_age(img, breed)
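            # Pace of aging = biological / chronological age; values above
            # 1.0 suggest faster-than-expected aging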
            pace = round(bio_age / chrono_age, 2) if chrono_age else None

            report = f"**Breed:** {breed} ({b_conf:.1%})  \n"
            report += f"**MedGemma Finding:** {med_label} ({med_conf:.1%})  \n\n"
            report += f"**Biological Age:** {bio_age} yrs  \n"
            report += f"**Chronological Age:** {chrono_age or 'N/A'} yrs  \n"
            if pace:
                report += f"**Pace of Aging:** {pace}×  \n\n"
            report += "### Health Aspects\n"
            for k, v in health.items():
                report += f"- **{k}:** {v['assessment']} ({v['confidence']:.1%})\n"
            return report

        btn_img.click(run_image, inputs=[img_in, breed_in, age_in], outputs=out_md)

    with gr.Tab("Video Analysis"):
        video_in = gr.Video(label="Upload or Record Video (10–30s)")
        btn_vid = gr.Button("Analyze Video")
        vid_out = gr.JSON()

        btn_vid.click(analyze_video_health, inputs=video_in, outputs=vid_out)

    with gr.Tab("Healthspan Questionnaire"):
        widgets = []
        for section in QUESTIONNAIRE:
            gr.Markdown(f"### {section['domain']}")
            for q in section["questions"]:
                w = gr.Radio(choices=SCALE, value="0", label=q)  # default to 0 so unanswered items don't break scoring
                widgets.append(w)
        btn_q = gr.Button("Compute Score")
        q_out = gr.JSON()

        btn_q.click(
            fn=compute_questionnaire_score,
            inputs=widgets,
            outputs=q_out
        )

    with gr.Tab("About"):
        gr.Markdown("""
        **MedGemma v1**: Veterinary medical image analysis  
        **Video Module**: Gait & posture confidence score  
        **Questionnaire**: Healthspan domains (Mobility, Energy, Physical, Cognitive, Social)
        """)

if __name__ == "__main__":
    demo.launch()