Requirements:

pip install transformers
pip install torch
pip install pillow requests

Adapted sample script for SRRG

import torch
from PIL import Image
from transformers import BertTokenizer, ViTImageProcessor, VisionEncoderDecoderModel, GenerationConfig
import requests
import re

model_name  = "StanfordAIMI/chexpert-plus-srrg_findings"
model = VisionEncoderDecoderModel.from_pretrained(model_name).eval()
tokenizer = BertTokenizer.from_pretrained(model_name)
image_processor = ViTImageProcessor.from_pretrained(model_name)
generation_args = {
   "bos_token_id": model.config.bos_token_id,
   "eos_token_id": model.config.eos_token_id,
   "pad_token_id": model.config.pad_token_id,
   "num_return_sequences": 1,
   "max_length": 128,
   "use_cache": True,
   "num_beams": 2,  # beam search width; generate() reads this as num_beams
}

# Inference
with torch.no_grad():
   url = "https://huggingface.co/IAMJB/interpret-cxr-impression-baseline/resolve/main/effusions-bibasal.jpg"
   image = Image.open(requests.get(url, stream=True).raw)
   pixel_values = image_processor(image, return_tensors="pt").pixel_values
   # Generate predictions
   generated_ids = model.generate(
       pixel_values,
       generation_config=GenerationConfig(
           **{**generation_args, "decoder_start_token_id": tokenizer.cls_token_id})
   )
   generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

report = generated_texts[0]
print("Output raw report:\n", report)

# Convert the one-line report into a multi-line structured report:
def process_report(report):

    VALID_ORGANS = [
        "Lungs and Airways",
        "Pleura",
        "Cardiovascular",
        "Hila and Mediastinum",
        "Tubes, Catheters, and Support Devices",
        "Musculoskeletal and Chest Wall",
        "Abdominal",
        "Other",
    ]

    # Build a regex that matches any valid organ (case-insensitive) followed by a colon.
    organ_pattern = re.compile(
        r'\b(?:' + '|'.join(re.escape(organ) for organ in VALID_ORGANS) + r')\s*:',
        re.IGNORECASE
    )

    first_found = False
    def replacement(match):
        nonlocal first_found
        if not first_found:
            first_found = True
            return match.group(0)
        else:
            return "\n\n" + match.group(0)
    return organ_pattern.sub(replacement, report)

report = report.strip()
report = report.replace(" :", ":").replace(" ,", ",").replace("-", "\n-")
report = process_report(report)
print("Formatted report:\n", report)

Output

Output raw report:
 pleura : - small bilateral pleural effusions. lungs and airways : - bibasilar opacities, left greater than right, suggestive of atelectasis or consolidation. - diffuse reticular pattern throughout the lungs. cardiovascular : - normal cardiomediastinal silhouette. musculoskeletal and chest wall : - no acute bony abnormalities.

Formatted report:
 pleura: 
- small bilateral pleural effusions. 

lungs and airways: 
- bibasilar opacities, left greater than right, suggestive of atelectasis or consolidation. 
- diffuse reticular pattern throughout the lungs. 

cardiovascular: 
- normal cardiomediastinal silhouette. 

musculoskeletal and chest wall: 
- no acute bony abnormalities.
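
For downstream use it can be handy to have the report as a dictionary rather than formatted text. The following is a minimal sketch, not part of the original script: `parse_report`, `SECTIONS`, and `HEADER` are hypothetical names, the section list is copied from `VALID_ORGANS` above, and the code assumes each finding is introduced by a dash that either opens a section or follows a period. Under that assumption a dash inside a finding (for example a hyphenated word decoded as "right - sided") is left in place.

```python
import re

# Section names copied from VALID_ORGANS in the script above.
SECTIONS = [
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes, Catheters, and Support Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    "Other",
]

# A section name followed by a colon, matched case-insensitively.
HEADER = re.compile(
    r"\b(" + "|".join(re.escape(s) for s in SECTIONS) + r")\s*:",
    re.IGNORECASE,
)

# A bullet dash: at the start of a section body or right after a period.
BULLET = re.compile(r"(?:^|(?<=\.))\s*-\s+")


def parse_report(raw):
    """Hypothetical helper: map each section name to its list of findings."""
    parsed = {}
    matches = list(HEADER.finditer(raw))
    for i, match in enumerate(matches):
        # The section body runs until the next header or the end of the text.
        end = matches[i + 1].start() if i + 1 < len(matches) else len(raw)
        body = raw[match.end():end]
        findings = [f.strip() for f in BULLET.split(body) if f.strip()]
        parsed.setdefault(match.group(1).lower(), []).extend(findings)
    return parsed


if __name__ == "__main__":
    example = (
        "pleura : - small bilateral pleural effusions. "
        "cardiovascular : - normal cardiomediastinal silhouette."
    )
    for section, findings in parse_report(example).items():
        print(section, findings)
```

The function works on the raw one-line report, so it can be called directly on `generated_texts[0]` before any of the string replacements are applied.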

Model size: 55.1M params (Safetensors)
Tensor type: F32