Spaces:
Sleeping
Sleeping
File size: 3,715 Bytes
d4d94d2 79c27a2 d4d94d2 79c27a2 d4d94d2 79c27a2 d4d94d2 79c27a2 d4d94d2 79c27a2 d4d94d2 e65d0e5 79c27a2 8c6ba75 d4d94d2 79c27a2 d4d94d2 79c27a2 d4d94d2 8c6ba75 e65d0e5 8c6ba75 d4d94d2 e65d0e5 d4d94d2 8c6ba75 79c27a2 d4d94d2 79c27a2 a54164e 8c6ba75 d4d94d2 8c6ba75 d4d94d2 8c6ba75 d4d94d2 8c6ba75 d4d94d2 8c6ba75 d4d94d2 79c27a2 8c6ba75 79c27a2 d4d94d2 79c27a2 d4d94d2 79c27a2 d4d94d2 8c6ba75 a54164e 8c6ba75 a54164e 8c6ba75 d4d94d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
# app.py
import torch
import numpy as np
from PIL import Image
import io
import gradio as gr
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download
from model import CombinedModel, ImageToTextProjector
import pydicom
import os
import gc
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_TOKEN = os.getenv("HF_TOKEN")

# NOTE: huggingface_hub / transformers work out their default cache location
# from HF_HOME when they are imported, and those imports have already run by
# the time this assignment executes.  The variable is still set (anything that
# reads it later will see it), but the cache directory is also passed
# explicitly to every download below so the files really land under /tmp.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
HF_CACHE_DIR = os.path.join(os.environ["HF_HOME"], "hub")

# Model loading
tokenizer = AutoTokenizer.from_pretrained(
    "baliddeki/phronesis-ml",
    token=HF_TOKEN,
    cache_dir=HF_CACHE_DIR,
)

# 3D ResNet-18 backbone with its classification head replaced by a
# 512-dimensional linear layer.
video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)

report_generator = AutoModelForSeq2SeqLM.from_pretrained(
    "GanjinZero/biobart-v2-base",
    cache_dir=HF_CACHE_DIR,
)
# Maps the 512-dimensional video features to the text model's hidden size.
projector = ImageToTextProjector(512, report_generator.config.d_model)

# The position of each name is the class index produced by the classifier;
# the class count is derived from the list so the two cannot drift apart.
class_names = ["acute", "normal", "chronic", "lacunar"]
num_classes = len(class_names)

combined_model = CombinedModel(
    video_model, report_generator, num_classes, projector, tokenizer
)

model_file = hf_hub_download(
    "baliddeki/phronesis-ml",
    "pytorch_model.bin",
    token=HF_TOKEN,
    cache_dir=HF_CACHE_DIR,
)
# NOTE(security): torch.load unpickles the downloaded file.  If the installed
# torch supports it, prefer torch.load(..., weights_only=True) for a plain
# state dict -- TODO confirm the minimum torch version for this app.
state_dict = torch.load(model_file, map_location=device)
combined_model.load_state_dict(state_dict)
combined_model.to(device)
combined_model.eval()

# Per-frame preprocessing: resize to 112x112, convert to a tensor and
# normalise each channel.  The mean/std values are presumably the statistics
# the pretrained video backbone expects -- confirm against torchvision.
image_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.43216, 0.394666, 0.37645],
        std=[0.22803, 0.22145, 0.216989],
    ),
])
def dicom_to_image(file_bytes):
    """Decode an in-memory DICOM file into an 8-bit RGB PIL image.

    The pixel data is min-max scaled to the 0-255 range before being
    converted to ``uint8``.

    Args:
        file_bytes: Raw bytes of a DICOM file.

    Returns:
        A PIL ``Image`` in RGB mode.
    """
    dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
    pixel_array = dicom_file.pixel_array.astype(np.float32)

    lowest = pixel_array.min()
    # np.ptp() rather than ndarray.ptp(): the method form was removed in
    # NumPy 2.0, the function form works on every version.
    value_range = np.ptp(pixel_array)

    if value_range > 0:
        pixel_array = (pixel_array - lowest) / value_range * 255.0
    else:
        # A constant-valued image has zero range; dividing would give NaN
        # (0/0), so render it as an all-black frame instead.
        pixel_array = np.zeros_like(pixel_array)

    return Image.fromarray(pixel_array.astype(np.uint8)).convert("RGB")
def _upload_name(file_obj):
    """Return the file name/path of an upload given as a path or a file object."""
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def _read_upload_bytes(file_obj):
    """Return the raw bytes of an upload given as a path or a file object."""
    if isinstance(file_obj, (str, os.PathLike)):
        with open(file_obj, "rb") as handle:
            return handle.read()
    return file_obj.read()


def _load_frame(file_obj):
    """Load one uploaded file (DICOM or ordinary image) as an RGB PIL image."""
    if _upload_name(file_obj).lower().endswith((".dcm", ".ima")):
        return dicom_to_image(_read_upload_bytes(file_obj))
    # convert() returns an independent copy, so the source handle can be
    # released as soon as the conversion is done.
    with Image.open(file_obj) as opened:
        return opened.convert("RGB")


def predict(files):
    """Classify an uploaded scan series and generate a text report for it.

    Every upload is decoded to an RGB image.  The series is then brought to
    exactly 16 frames: longer series are sampled at evenly spaced indices,
    shorter ones are padded by repeating the last frame.  The frames are
    transformed, stacked into a single batch and passed through
    ``combined_model``.

    Args:
        files: The uploaded files, each either a filesystem path (``str`` /
            ``os.PathLike``) or a file-like object exposing ``name`` and
            ``read()``.  Files ending in ``.dcm`` or ``.ima`` are read as
            DICOM; everything else is opened with PIL.

    Returns:
        A ``(class_name, report)`` tuple of strings.  When nothing was
        uploaded the tuple is ``("No images uploaded.", "")``.
    """
    if not files:
        return "No images uploaded.", ""

    processed_imgs = [_load_frame(file_obj) for file_obj in files]

    n_frames = 16
    if len(processed_imgs) >= n_frames:
        # Evenly spaced indices that always include the first and last frame.
        picks = np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
        images_sampled = [processed_imgs[i] for i in picks]
    else:
        missing = n_frames - len(processed_imgs)
        images_sampled = processed_imgs + [processed_imgs[-1]] * missing

    tensor_imgs = [image_transform(img) for img in images_sampled]
    # Stack the frames, swap the first two axes and add a leading batch axis
    # of size one before moving the tensor to the model's device.
    input_tensor = (
        torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
    )

    try:
        with torch.no_grad():
            class_logits, report, _ = combined_model(input_tensor)
        class_pred = torch.argmax(class_logits, dim=1).item()
    finally:
        # Release memory between requests, even when inference fails.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    report_text = report[0] if report else "No report generated."
    return class_names[class_pred], report_text
# Gradio UI: uploads are kept in a gr.State value, and inference runs only
# when the user presses "Generate Report".
with gr.Blocks() as demo:
    gr.Markdown("# 🩺 Phronesis Medical Report Generator")
    upload_button = gr.UploadButton(
        "Upload CT Scan Images",
        # ".ima" is offered because predict() reads it as DICOM, exactly
        # like ".dcm"; without it the file picker would filter those out.
        file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png"],
        file_count="multiple",
    )
    files_state = gr.State([])

    def store_files(new_files):
        """Return the new uploads unchanged so they replace the stored list."""
        return new_files

    # Each upload overwrites the previously stored list of files.
    upload_button.upload(store_files, upload_button, files_state)

    generate_btn = gr.Button("Generate Report")
    class_output = gr.Textbox(label="Predicted Class")
    report_output = gr.Textbox(label="Generated Report")
    # predict() returns (class name, report); they fill the two text boxes
    # in that order.
    generate_btn.click(predict, files_state, [class_output, report_output])

demo.launch()
|