import gradio as gr
from PIL import Image
import torch
import numpy as np
import cv2
from transformers import CLIPProcessor, CLIPModel
# Hypothetical MedGemma imports (ensure you have model access and an HF token)
from medgem import MedGemmaProcessor, MedGemmaForImageClassification
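# NOTE: "medgem" and "google/medgemma-v1" are placeholders. If you are working
# with the publicly released MedGemma checkpoints, they are image-text-to-text
# models and would typically be loaded through the transformers Auto classes
# instead, e.g. (assumed API, adjust the model id to the checkpoint you have):
#   from transformers import AutoProcessor, AutoModelForImageTextToText
#   processor = AutoProcessor.from_pretrained("google/medgemma-4b-it")
#   model = AutoModelForImageTextToText.from_pretrained("google/medgemma-4b-it")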
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load CLIP model for breed, age, and basic health aspects
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
# Load MedGemma for advanced medical insights
medgemma_processor = MedGemmaProcessor.from_pretrained("google/medgemma-v1")  # processors hold no weights, so no .to(device)
medgemma_model = MedGemmaForImageClassification.from_pretrained("google/medgemma-v1").to(device)
# Stanford Dogs breeds list & lifespan dict
STANFORD_BREEDS = [
"afghan hound", "african hunting dog", "airedale", "american staffordshire terrier",
# ... (full list from earlier)
"wire-haired fox terrier", "yorkshire terrier"
]
BREED_LIFESPAN = {
"afghan hound": 11.1, "african hunting dog": 10.5, "airedale": 11.5,
# ... (full dict from earlier)
"yorkshire terrier": 13.3
}
# Healthspan questionnaire definitions
QUESTIONNAIRE = [
{"domain": "Mobility", "questions": [
"Does your dog have difficulty rising from lying down?",
"Does your dog hesitate before jumping up?"
]},
{"domain": "Energy", "questions": [
"Does your dog tire quickly on walks?",
"Has your dog’s activity level decreased recently?"
]},
{"domain": "Physical Health", "questions": [
"Does your dog scratch or lick skin frequently?",
"Any noticeable changes in appetite or weight?"
]},
{"domain": "Cognitive", "questions": [
"Does your dog get lost in familiar rooms?",
"Does your dog stare blankly at walls/windows?"
]},
{"domain": "Social", "questions": [
"Has your dog’s interest in play declined?",
"Does your dog avoid interaction with family?"
]}
]
# Unified scoring map for questionnaire (0–5 scale)
SCALE = ["0", "1", "2", "3", "4", "5"]
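# gr.Radio returns the selected choice as a string, so the scale values are
# kept as strings here and cast back to int in compute_questionnaire_score.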
def predict_biological_age(image: Image.Image, breed: str) -> int:
avg = BREED_LIFESPAN.get(breed.lower(), 12)
prompts = [f"a photo of a {age}-year-old {breed}" for age in range(1, int(avg * 2) + 1)]
inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
logits = clip_model(**inputs).logits_per_image.softmax(dim=1)[0].cpu().numpy()
return int(np.argmax(logits) + 1)
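# Example (hypothetical): predict_biological_age(Image.open("dog.jpg"), "beagle")
# builds prompts "a photo of a 1-year-old beagle", "a photo of a 2-year-old
# beagle", ... up to twice the breed's average lifespan, and returns the age
# whose prompt CLIP scores highest against the image.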
def analyze_medical_image(image: Image.Image):
inputs = medgemma_processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
outputs = medgemma_model(**inputs)
probs = outputs.logits.softmax(dim=1)[0].cpu().numpy()
label = medgemma_model.config.id2label[np.argmax(probs)]
conf = float(np.max(probs))
return label, conf
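# NOTE: this assumes a checkpoint that exposes classification logits and an
# id2label mapping in its config. The released MedGemma models are generative
# (image-text-to-text), so a real integration would prompt the model and parse
# its text output rather than read logits from a classification head.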
def classify_breed_and_health(image: Image.Image, user_breed=None):
    # Image features, L2-normalized so dot products are cosine similarities
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_feats = clip_model.get_image_features(**inputs)
    img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
    # Breed classification via zero-shot text prompts
    texts = [f"a photo of a {b}" for b in STANFORD_BREEDS]
    t_in = clip_processor(text=texts, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_feats = clip_model.get_text_features(**t_in)
    text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
    # 100.0 approximates CLIP's learned logit scale before the softmax
    sims = (100.0 * img_feats @ text_feats.T).softmax(dim=-1)[0].cpu().numpy()
    idx = int(sims.argmax())
    breed = user_breed or STANFORD_BREEDS[idx]
    breed_conf = float(sims[idx])
# Basic health aspects via CLIP
aspects = {
"Coat": ("shiny healthy coat", "dull patchy fur"),
"Eyes": ("bright clear eyes", "cloudy milky eyes"),
"Body": ("ideal muscle tone", "visible ribs or hip bones"),
"Teeth": ("clean white teeth", "yellow stained teeth")
}
health = {}
    for name, (pos, neg) in aspects.items():
        txt = clip_processor(text=[pos, neg], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            tf = clip_model.get_text_features(**txt)
        tf = tf / tf.norm(dim=-1, keepdim=True)
        sim = (100.0 * img_feats @ tf.T).softmax(dim=-1)[0].cpu().numpy()
        choice = pos if sim[0] > sim[1] else neg
        health[name] = {"assessment": choice, "confidence": float(max(sim))}
return breed, breed_conf, health
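# Each aspect uses a positive/negative prompt pair as a zero-shot contrast:
# whichever description sits closer to the image in CLIP embedding space is
# reported, with the pair's softmax as a rough confidence. These are visual
# heuristics, not veterinary diagnoses.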
def analyze_video_health(video_path: str):
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 24
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if not cap.isOpened() or total_frames <= 0:
        cap.release()
        return {"duration_sec": 0.0, "avg_gait_confidence": 0.0}
    duration = total_frames / fps
    # sample 10 frames evenly across the clip
    indices = np.linspace(0, total_frames - 1, num=10, dtype=int)
    gait_scores = []
    for i in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if not ret:
            continue  # skip unreadable frames instead of abandoning the rest
img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
_, conf = analyze_medical_image(img)
gait_scores.append(conf)
cap.release()
avg_conf = float(np.mean(gait_scores)) if gait_scores else 0.0
return {"duration_sec": round(duration, 1), "avg_gait_confidence": avg_conf}
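# Uniform sampling keeps inference cost constant regardless of clip length.
# Note that "avg_gait_confidence" is the mean per-frame classifier confidence,
# a rough proxy; true gait analysis would track pose across consecutive frames.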
def compute_questionnaire_score(answers: list):
    # answers arrive flattened in QUESTIONNAIRE order (domains × questions);
    # unanswered radio buttons come through as None and count as 0
    scores = {}
    idx = 0
    for section in QUESTIONNAIRE:
        n = len(section["questions"])
        vals = [int(a) if a is not None else 0 for a in answers[idx: idx + n]]
        idx += n
        scores[section["domain"]] = round(sum(vals) / n, 2)
return scores
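# Example: with two questions per domain, answers
# ["0", "1", "2", "0", "3", "1", "0", "2", "4", "1"] yield
# {"Mobility": 0.5, "Energy": 1.0, "Physical Health": 2.0, "Cognitive": 1.0,
#  "Social": 2.5}, i.e. the per-domain mean of the 0-5 ratings.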
# Build Gradio interface
with gr.Blocks(title="🐶 Dog Health & Age Analyzer") as demo:
gr.Markdown("## Upload an Image or Video, or Record a Short Clip")
with gr.Tab("Image Analysis"):
img_in = gr.Image(type="pil", label="Upload Dog Image")
breed_in = gr.Textbox(label="(Optional) Override Breed")
age_in = gr.Number(label="Chronological Age (years)", precision=1)
btn_img = gr.Button("Analyze Image")
out_md = gr.Markdown()
        def run_image(img, breed_override, chrono_age):
            if img is None:
                return "Please upload an image first."
            breed, b_conf, health = classify_breed_and_health(img, breed_override)
            med_label, med_conf = analyze_medical_image(img)
            bio_age = predict_biological_age(img, breed)
            pace = round(bio_age / chrono_age, 2) if chrono_age else None
            # two trailing spaces before \n produce Markdown line breaks
            report = f"**Breed:** {breed} ({b_conf:.1%})  \n"
            report += f"**MedGemma Finding:** {med_label} ({med_conf:.1%})\n\n"
            report += f"**Biological Age:** {bio_age} yrs  \n"
            report += f"**Chronological Age:** {chrono_age or 'N/A'} yrs  \n"
            if pace:
                report += f"**Pace of Aging:** {pace}×\n\n"
            report += "### Health Aspects\n"
            for k, v in health.items():
                report += f"- **{k}:** {v['assessment']} ({v['confidence']:.1%})\n"
return report
btn_img.click(run_image, inputs=[img_in, breed_in, age_in], outputs=out_md)
with gr.Tab("Video Analysis"):
video_in = gr.Video(label="Upload or Record Video (10–30s)")
btn_vid = gr.Button("Analyze Video")
vid_out = gr.JSON()
        btn_vid.click(analyze_video_health, inputs=video_in, outputs=vid_out)
with gr.Tab("Healthspan Questionnaire"):
widgets = []
for section in QUESTIONNAIRE:
gr.Markdown(f"### {section['domain']}")
for q in section["questions"]:
                w = gr.Radio(choices=SCALE, value="0", label=q)  # default to 0 so unanswered items score 0
widgets.append(w)
btn_q = gr.Button("Compute Score")
q_out = gr.JSON()
btn_q.click(
fn=compute_questionnaire_score,
inputs=widgets,
outputs=q_out
)
with gr.Tab("About"):
gr.Markdown("""
**MedGemma v1**: Veterinary medical image analysis
**Video Module**: Gait & posture confidence score
**Questionnaire**: Healthspan domains (Mobility, Energy, Physical, Cognitive, Social)
""")
demo.launch()