Requirements:

pip install git+https://github.com/huggingface/transformers.git@88d960937c81a32bfb63356a2e8ecf7999619681
pip install protobuf
pip install peft
pip install sentencepiece
pip install Pillow

Adapted MAIRA-2 sample script for SRRG impression generation (uses the `StanfordAIMI/maira2-srrg-impression` LoRA adapter):

from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig
from peft import PeftConfig, PeftModel
import torch
import requests
from PIL import Image

# 1. Load the PEFT adapter configuration to discover which base model it targets.
#    The adapter id is named once so the config and the weights cannot drift apart.
ADAPTER_ID = "StanfordAIMI/maira2-srrg-impression"
peft_config = PeftConfig.from_pretrained(ADAPTER_ID)
base_model_name = peft_config.base_model_name_or_path

# 2. Pick the device: GPU when one is available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3. Load the base model config (trust_remote_code is required because the
#    base model repository ships its own modelling code).
#    NOTE(review): the config is used unmodified. An earlier revision of this
#    script set `config.parallelize = {}` here to avoid a NoneType iteration
#    error; re-add that line only if the error shows up with your
#    transformers version.
config = AutoConfig.from_pretrained(
    base_model_name,
    trust_remote_code=True
)

# 4. Load the base MAIRA-2 model with that config
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    config=config,
    trust_remote_code=True
)

# 5. Attach the LoRA adapter.
#    NOTE(review): torch_dtype/device_map are forwarded to PEFT as extra
#    keyword arguments; whether they affect the adapter weights depends on
#    the installed peft version - confirm before relying on fp16 here.
model = PeftModel.from_pretrained(
    model,
    ADAPTER_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
# Switch to inference mode (disables dropout) and move everything to `device`
model.eval().to(device)

# 6. Initialize the vision-language processor (tokenizer + image preprocessing).
#    `processor` is a plain module-level name; no `global` statement is needed
#    at module scope.
processor = AutoProcessor.from_pretrained(
    base_model_name,
    trust_remote_code=True
)

# 7. Helper to fetch example chest X-rays
def get_sample_data():
    """Download an example chest X-ray and return it with sample report context.

    Returns:
        dict with keys:
            "frontal":    the frontal image as an RGB ``PIL.Image.Image``
            "indication": free-text indication (or "")
            "comparison": free-text comparison (or "")
            "technique":  free-text technique (or "")

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
    """
    # Local import keeps this helper self-contained within the sample script.
    from io import BytesIO

    urls = {
        "frontal": "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png",
        # Not downloaded below; kept so the matching lateral view is easy to add.
        "lateral": "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
    }

    def _download(url):
        # A timeout prevents the script from hanging forever on a stalled server.
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, timeout=30)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # Decode from the fully-read, content-decoded body rather than the raw
        # socket stream, so the connection is released and compressed
        # transfer encodings are handled.
        return Image.open(BytesIO(response.content)).convert("RGB")

    return {
        "frontal": _download(urls["frontal"]), # we use frontal only
        "indication": "Dyspnea.", # actual indication goes here or ""
        "comparison": "None.", # actual comparison goes here or ""
        "technique": "PA and lateral views of the chest.", # actual technique goes here or ""
    }

# 8. Build the model inputs from the sample study and generate without grounding
sample = get_sample_data()
# Only the current frontal view is supplied; the lateral view, prior image and
# prior report are explicitly passed as None.
reporting_kwargs = dict(
    current_frontal=sample["frontal"],
    current_lateral=None,
    prior_frontal=None,
    indication=sample["indication"],
    technique=sample["technique"],
    comparison=sample["comparison"],
    prior_report=None,
    return_tensors="pt",
    get_grounding=False,
)
inputs = processor.format_and_preprocess_reporting_input(**reporting_kwargs)
inputs = inputs.to(device)

# Gradients are not needed for generation
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=300, use_cache=True)

# 9. Decode and print impression
# Drop the prompt tokens so that only the newly generated tokens are decoded
prompt_len = inputs["input_ids"].shape[-1]
generated_tokens = output_ids[0][prompt_len:]
decoded = processor.decode(generated_tokens, skip_special_tokens=True)
decoded = decoded.lstrip()
prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded)
print("Generated Impression:\n", prediction)

Output

Generated Impression: 
1. Large right pleural effusion with associated compressive atelectasis of the right lower lobe.
2. The left lung is clear.
3. No evidence of pneumothorax.
4. No signs of pulmonary edema.
5. Cardiac and mediastinal contours are normal.
6. No acute bony abnormalities.
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including StanfordAIMI/maira2-srrg-impression