# CXR / app.py
# Last commit e683b39 ("Update app.py", verified) by shubham5524 · 4.02 kB
import os
import threading
from io import BytesIO
from typing import List, Tuple

from fastapi import FastAPI, Query
from fastapi import Body
from fastapi import HTTPException
from pydantic import BaseModel
import torch
import torchxrayvision as xrv
import torchvision
import skimage.io
import numpy as np
import requests
import cv2
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Point caches/config at writable locations. These assignments must run BEFORE
# the model below is constructed: previously they came after the weights had
# already been loaded, so they could not influence that load.
# NOTE(review): confirm the installed torchxrayvision version actually reads
# TORCHXRAYVISION_CACHE; if not, this line has no effect on where weights go.
os.environ['TORCHXRAYVISION_CACHE'] = './torchxrayvision_cache'
# NOTE(review): matplotlib is imported at the top of this module, before this
# line runs, so MPLCONFIGDIR set here is likely too late to take effect. Move it
# above the imports (or drop the matplotlib import, which nothing here uses)
# if the font-cache warning matters.
os.environ['MPLCONFIGDIR'] = './.config/matplotlib'  # For matplotlib font issues

app = FastAPI()

# Single DenseNet instance, built once at import time and shared by every
# request handled by this process.
model = xrv.models.DenseNet(weights="densenet121-res224-all")
model.eval()  # inference mode (disables dropout / batch-norm updates)
def preprocess_image_from_url(image_url: str, timeout: float = 30.0) -> torch.Tensor:
    """Download an image and convert it to the input layout the xrv model takes.

    Pipeline: HTTP GET -> decode with ``skimage.io.imread`` -> rescale with
    ``xrv.datasets.normalize(img, 255)`` (8-bit input is assumed) -> average
    colour channels into one channel -> ``XRayCenterCrop`` + ``XRayResizer(224)``.

    Args:
        image_url: URL to fetch. NOTE(review): this is caller-supplied and is
            fetched server-side (SSRF risk); restrict allowed hosts if the
            service is publicly exposed.
        timeout: Seconds ``requests`` waits to connect / to read before giving
            up. Without a timeout a stalled server blocks the worker forever.

    Returns:
        A tensor with a leading channel axis of size 1 and no batch axis.

    Raises:
        requests.RequestException: the download failed or the server answered
            with an HTTP error status (4xx/5xx).
        ValueError: the decoded image is neither 2-D nor 3-D.
        Decoding/normalization errors from skimage or torchxrayvision propagate
        unchanged (presumably including images with values above 255, e.g.
        16-bit files — confirm against the installed xrv version).
    """
    response = requests.get(image_url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to decode an HTML error page.
    response.raise_for_status()

    img = skimage.io.imread(BytesIO(response.content))
    img = xrv.datasets.normalize(img, 255)

    if img.ndim == 3:
        # 2- and 4-channel images are assumed to be gray+alpha / RGBA: drop the
        # alpha channel so it does not get averaged into the intensity.
        if img.shape[2] in (2, 4):
            img = img[..., :-1]
        img = img.mean(2)[None, ...]  # collapse colour channels, add channel axis
    elif img.ndim == 2:
        img = img[None, ...]  # grayscale: just add the channel axis
    else:
        raise ValueError(
            f"Unsupported image with {img.ndim} dimensions; "
            "expected a 2-D grayscale or 3-D colour image."
        )

    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224),
    ])
    img = transform(img)
    return torch.from_numpy(img)
# Serializes all use of the shared global `model`. GradCAM registers forward
# hooks on one of its layers and calls backward() through it, and FastAPI runs
# plain `def` endpoints in a thread pool. Without this lock, concurrent requests
# can record each other's activations/gradients and produce wrong heatmaps.
_MODEL_LOCK = threading.Lock()


def _cam_to_bounding_boxes(cam_map: np.ndarray, threshold: int) -> List[Tuple[int, int, int, int]]:
    """Return ``(x_min, y_min, x_max, y_max)`` boxes around CAM regions above *threshold*.

    *cam_map* is a 2-D float map expected in [0, 1]. It is scaled to 0-255
    before thresholding, so *threshold* is on that 0-255 scale.
    """
    # Clip so out-of-range values cannot wrap around in the uint8 cast.
    cam_uint8 = (np.clip(cam_map, 0.0, 1.0) * 255).astype(np.uint8)
    _, mask = cv2.threshold(cam_uint8, threshold, 255, cv2.THRESH_BINARY)
    # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and the
    # OpenCV 4 (2-tuple) return signatures.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append((int(x), int(y), int(x + w), int(y + h)))
    return boxes


def get_predictions_and_bounding_box(img_tensor: torch.Tensor, cam_threshold: int = 100):
    """Classify a preprocessed X-ray and localize the top finding with Grad-CAM.

    Args:
        img_tensor: Output of ``preprocess_image_from_url``. It is assumed to have
            a channel axis but no batch axis (one is added here); TODO confirm
            if called with other inputs.
        cam_threshold: Grad-CAM cutoff on a 0-255 scale. Regions above it become
            bounding boxes. The default, 100, is the previously hard-coded value.

    Returns:
        ``(sorted_preds, bounding_boxes)``. ``sorted_preds`` is a list of
        ``(pathology, score)`` pairs ordered by descending score.
        ``bounding_boxes`` is a list of ``(x_min, y_min, x_max, y_max)`` pixel
        boxes, in the input image's frame, for the top-scoring pathology.
    """
    batch = img_tensor[None, ...]
    with _MODEL_LOCK:
        with torch.no_grad():
            output = model(batch)[0]
        predictions = dict(zip(model.pathologies, output.numpy()))
        sorted_preds = sorted(predictions.items(), key=lambda item: -item[1])
        top_label, _ = sorted_preds[0]
        top_index = list(model.pathologies).index(top_label)

        # Using GradCAM as a context manager removes its hooks from
        # model.features[-1] on exit, instead of relying on garbage collection
        # of the `cam` object.
        with GradCAM(model=model, target_layers=[model.features[-1]]) as cam:
            grayscale_cam = cam(
                input_tensor=batch,
                targets=[ClassifierOutputTarget(top_index)],
            )[0, :]

    # Bring the heatmap to the input's spatial size (cv2 expects (width, height)).
    height, width = img_tensor.shape[-2:]
    cam_resized = cv2.resize(grayscale_cam, (int(width), int(height)))
    return sorted_preds, _cam_to_bounding_boxes(cam_resized, cam_threshold)
# One entry of the /predict response: a pathology name and the model's score
# for it. The endpoints convert the score with float() before building this.
class Prediction(BaseModel):
    label: str  # pathology name, taken from model.pathologies
    confidence: float  # model output for this pathology
# Success response body shared by the GET and POST /predict endpoints.
class PredictionResponse(BaseModel):
    # Every pathology with its score, sorted by descending score.
    predictions: List[Prediction]
    # (x_min, y_min, x_max, y_max) pixel boxes, in the preprocessed 224x224
    # image, around thresholded Grad-CAM regions for the top-scoring pathology.
    top_prediction_bounding_boxes: List[Tuple[int, int, int, int]]
@app.get("/predict", response_model=PredictionResponse)
def predict(image_url: str = Query(..., description="URL of chest X-ray image")):
    """Classify the chest X-ray at ``image_url`` and localize the top finding.

    Errors are raised as ``HTTPException`` with the underlying message as
    ``detail``:
      * 400 if the image could not be downloaded (network error or HTTP error status),
      * 422 if the download succeeded but the image could not be decoded/used,
      * 500 for any other failure while preprocessing or running the model.
    Returning an ``{"error": ...}`` dict here would fail validation against
    ``response_model`` and reach the client only as an opaque 500.
    """
    try:
        img_tensor = preprocess_image_from_url(image_url)
        preds, bboxes = get_predictions_and_bounding_box(img_tensor)
    # requests.RequestException subclasses OSError, so it must be caught
    # before the (ValueError, OSError) clause below.
    except requests.RequestException as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except (ValueError, OSError) as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
    return PredictionResponse(
        predictions=prediction_list,
        top_prediction_bounding_boxes=bboxes,
    )
# JSON request body for POST /predict, e.g. {"url": "<image URL>"}.
class URLRequest(BaseModel):
    url: str  # URL of the image to download and classify
@app.post("/predict", response_model=PredictionResponse)
def predict_from_url(body: URLRequest):
    """POST variant of ``/predict``: the image URL arrives as JSON ``{"url": ...}``.

    Errors are raised as ``HTTPException`` with the underlying message as
    ``detail``:
      * 400 if the image could not be downloaded (network error or HTTP error status),
      * 422 if the download succeeded but the image could not be decoded/used,
      * 500 for any other failure while preprocessing or running the model.
    Returning an ``{"error": ...}`` dict here would fail validation against
    ``response_model`` and reach the client only as an opaque 500.
    """
    try:
        img_tensor = preprocess_image_from_url(body.url)
        preds, bboxes = get_predictions_and_bounding_box(img_tensor)
    # requests.RequestException subclasses OSError, so it must be caught
    # before the (ValueError, OSError) clause below.
    except requests.RequestException as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except (ValueError, OSError) as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
    return PredictionResponse(
        predictions=prediction_list,
        top_prediction_bounding_boxes=bboxes,
    )
# Run locally with: uvicorn app:app --reload