"""FastAPI service that classifies a chest X-ray fetched from a URL.

The service runs a torchxrayvision DenseNet over the image, returns every
pathology score sorted from highest to lowest, and returns bounding boxes
around the Grad-CAM hot spots of the top-scoring pathology.
"""
import os

# Configuration that libraries read at import/first-use time has to be in
# place BEFORE those libraries are imported, otherwise it is silently ignored.
# Set cache directory to a writable location
# NOTE(review): confirm torchxrayvision actually honours this variable; its
# model constructors also accept an explicit ``cache_dir`` argument.
os.environ['TORCHXRAYVISION_CACHE'] = './torchxrayvision_cache'
os.environ['MPLCONFIGDIR'] = './.config/matplotlib'  # For matplotlib font issues

import logging  # noqa: E402
from io import BytesIO  # noqa: E402
from typing import List, Tuple  # noqa: E402
from urllib.parse import urlparse  # noqa: E402

import cv2  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402,F401
import numpy as np  # noqa: E402
import requests  # noqa: E402
import skimage.io  # noqa: E402
import torch  # noqa: E402
import torchvision  # noqa: E402
import torchxrayvision as xrv  # noqa: E402
from fastapi import Body, FastAPI, HTTPException, Query  # noqa: E402,F401
from pydantic import BaseModel  # noqa: E402
from pytorch_grad_cam import GradCAM  # noqa: E402
from pytorch_grad_cam.utils.image import show_cam_on_image  # noqa: E402,F401
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget  # noqa: E402

logger = logging.getLogger(__name__)

# Side length (pixels) the image is resized to and the CAM is scaled back to.
INPUT_SIZE = 224
# Grad-CAM activations are scaled to 0-255; pixels above this value are
# treated as belonging to a highlighted region.
CAM_THRESHOLD = 100
# Upper bound on how long a remote image download may take.
REQUEST_TIMEOUT_SECONDS = 10

app = FastAPI()

model = xrv.models.DenseNet(weights="densenet121-res224-all")
model.eval()


def preprocess_image_from_url(image_url: str) -> torch.Tensor:
    """Download an image and convert it to the tensor the model is fed.

    Args:
        image_url: http(s) URL of the image to fetch.

    Returns:
        A tensor with a leading channel axis, centre-cropped and resized to
        ``INPUT_SIZE`` by the torchxrayvision transforms.

    Raises:
        ValueError: If the URL scheme is not http/https, or the decoded image
            is neither 2-D nor 3-D.
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    # SECURITY: the URL is caller-supplied. Only the scheme is checked here;
    # the host is NOT restricted, so this endpoint can be pointed at internal
    # addresses (SSRF). Add an allow-list / private-range block if the
    # service is exposed to untrusted clients.
    scheme = urlparse(image_url).scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"Unsupported URL scheme: {scheme!r}")

    response = requests.get(image_url, timeout=REQUEST_TIMEOUT_SECONDS)
    # Without this an HTML error page would be handed to the image decoder.
    response.raise_for_status()

    img = skimage.io.imread(BytesIO(response.content))
    # Assumes an 8-bit source image (max value 255) -- TODO confirm for
    # 16-bit inputs.
    img = xrv.datasets.normalize(img, 255)

    if img.ndim == 3:
        # Collapse the colour channels into one grey channel.
        img = img.mean(2)[None, ...]
    elif img.ndim == 2:
        img = img[None, ...]
    else:
        raise ValueError(f"Unsupported image shape: {img.shape}")

    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(INPUT_SIZE),
    ])
    img = transform(img)
    return torch.from_numpy(img)


def get_predictions_and_bounding_box(img_tensor: torch.Tensor):
    """Score an image and locate the regions driving the top prediction.

    Args:
        img_tensor: Preprocessed image as returned by
            ``preprocess_image_from_url`` (no batch axis).

    Returns:
        A ``(sorted_preds, bounding_boxes)`` tuple. ``sorted_preds`` is a list
        of ``(label, score)`` pairs ordered by descending score.
        ``bounding_boxes`` is a list of ``(x1, y1, x2, y2)`` integer tuples,
        one per connected Grad-CAM region above ``CAM_THRESHOLD``.
    """
    batch = img_tensor[None, ...]

    with torch.no_grad():
        output = model(batch)[0]

    predictions = dict(zip(model.pathologies, output.numpy()))
    sorted_preds = sorted(predictions.items(), key=lambda x: -x[1])
    top_pred_label, _top_conf = sorted_preds[0]
    top_pred_index = list(model.pathologies).index(top_pred_label)

    target_layer = model.features[-1]
    # The context manager releases the hooks GradCAM registers on the model;
    # otherwise every request would leave its hooks attached until GC.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(
            input_tensor=batch,
            targets=[ClassifierOutputTarget(top_pred_index)],
        )[0, :]

    cam_resized = cv2.resize(grayscale_cam, (INPUT_SIZE, INPUT_SIZE))
    cam_uint8 = (cam_resized * 255).astype(np.uint8)
    _, thresh = cv2.threshold(cam_uint8, CAM_THRESHOLD, 255, cv2.THRESH_BINARY)
    # [-2] picks the contour list under both the 2-value (OpenCV 4) and the
    # 3-value (OpenCV 3) return signatures of findContours.
    contours = cv2.findContours(
        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]

    bounding_boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        bounding_boxes.append((int(x), int(y), int(x + w), int(y + h)))

    return sorted_preds, bounding_boxes


class Prediction(BaseModel):
    """One pathology label with the score the model gave it."""

    label: str
    confidence: float


class PredictionResponse(BaseModel):
    """Response body of the ``/predict`` endpoints."""

    # All pathology scores, highest first.
    predictions: List[Prediction]
    # (x1, y1, x2, y2) boxes around Grad-CAM regions of the top prediction.
    top_prediction_bounding_boxes: List[Tuple[int, int, int, int]]


def _predict_for_url(image_url: str) -> PredictionResponse:
    """Run the whole pipeline for one URL, mapping failures to HTTP errors.

    Errors are raised as ``HTTPException`` rather than returned as a plain
    ``{"error": ...}`` dict: such a dict does not match ``PredictionResponse``
    and would be rejected by FastAPI's response validation, hiding the real
    cause behind a generic 500.

    Raises:
        HTTPException: 502 if the image cannot be downloaded, 400 if it
            cannot be decoded/preprocessed, 500 if inference fails.
    """
    try:
        img_tensor = preprocess_image_from_url(image_url)
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    try:
        preds, bboxes = get_predictions_and_bounding_box(img_tensor)
    except Exception as exc:
        logger.exception("Inference failed for %s", image_url)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    prediction_list = [
        Prediction(label=label, confidence=float(conf))
        for label, conf in preds
    ]
    return PredictionResponse(
        predictions=prediction_list,
        top_prediction_bounding_boxes=bboxes,
    )


@app.get("/predict", response_model=PredictionResponse)
def predict(image_url: str = Query(..., description="URL of chest X-ray image")):
    """Classify the X-ray at ``image_url`` (passed as a query parameter)."""
    return _predict_for_url(image_url)


class URLRequest(BaseModel):
    """Request body of ``POST /predict``."""

    url: str


@app.post("/predict", response_model=PredictionResponse)
def predict_from_url(body: URLRequest):
    """Classify the X-ray at ``body.url`` (passed in a JSON body)."""
    return _predict_for_url(body.url)

# uvicorn app:app --reload