import os

# Point caches at writable locations before the libraries that use them are imported/loaded.
os.environ['TORCHXRAYVISION_CACHE'] = './torchxrayvision_cache'  # torchxrayvision model weights
os.environ['MPLCONFIGDIR'] = './.config/matplotlib'              # matplotlib font cache

from io import BytesIO
from typing import List, Tuple

import cv2
import numpy as np
import requests
import skimage.io
import torch
import torchvision
import torchxrayvision as xrv
import matplotlib.pyplot as plt
from fastapi import FastAPI, Query, Body
from pydantic import BaseModel
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

app = FastAPI()

# Load the multi-dataset DenseNet121 chest X-ray classifier and put it in eval mode.
model = xrv.models.DenseNet(weights="densenet121-res224-all")
model.eval()
def preprocess_image_from_url(image_url: str) -> torch.Tensor:
    # Download the image and fail early on HTTP errors.
    response = requests.get(image_url)
    response.raise_for_status()
    img = skimage.io.imread(BytesIO(response.content))

    # Scale pixel values to the range torchxrayvision models expect (roughly [-1024, 1024]).
    img = xrv.datasets.normalize(img, 255)

    # Collapse RGB to a single channel; ensure the shape is (1, H, W).
    if img.ndim == 3:
        img = img.mean(2)[None, ...]
    elif img.ndim == 2:
        img = img[None, ...]

    # Center-crop and resize to the 224x224 input size of the model.
    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224)
    ])
    img = transform(img)

    img_tensor = torch.from_numpy(img)
    return img_tensor
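# Usage sketch (hypothetical URL): the returned tensor has shape (1, 224, 224), scaled to
# roughly [-1024, 1024] by xrv.datasets.normalize, which is what the model expects:
#   img_tensor = preprocess_image_from_url("https://example.com/chest_xray.jpg")
#   print(img_tensor.shape)  # torch.Size([1, 224, 224])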
def get_predictions_and_bounding_box(img_tensor: torch.Tensor):
    # Classify the image; the output is one score per pathology.
    with torch.no_grad():
        output = model(img_tensor[None, ...])[0]
    predictions = dict(zip(model.pathologies, output.numpy()))
    sorted_preds = sorted(predictions.items(), key=lambda x: -x[1])
    top_pred_label, top_conf = sorted_preds[0]
    top_pred_index = list(model.pathologies).index(top_pred_label)

    # Grad-CAM on the last feature block, targeting the top-scoring pathology.
    target_layer = model.features[-1]
    cam = GradCAM(model=model, target_layers=[target_layer])
    grayscale_cam = cam(input_tensor=img_tensor[None, ...],
                        targets=[ClassifierOutputTarget(top_pred_index)])[0, :]

    # Normalize the input to an 8-bit RGB image for visualization.
    input_img = img_tensor.numpy()[0]
    input_img_norm = (input_img - input_img.min()) / (input_img.max() - input_img.min())
    input_img_rgb = cv2.cvtColor((input_img_norm * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)

    # Threshold the CAM heatmap and extract bounding boxes around the activated regions.
    cam_resized = cv2.resize(grayscale_cam, (224, 224))
    cam_uint8 = (cam_resized * 255).astype(np.uint8)
    _, thresh = cv2.threshold(cam_uint8, 100, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    bounding_boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        bounding_boxes.append((int(x), int(y), int(x + w), int(y + h)))

    return sorted_preds, bounding_boxes
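# Usage sketch: the boxes are in the 224x224 coordinate space of the preprocessed image, so
# they can be drawn directly onto a 224x224 visualization (vis_img below is hypothetical):
#   preds, boxes = get_predictions_and_bounding_box(img_tensor)
#   for x1, y1, x2, y2 in boxes:
#       cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 255, 0), 2)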
class Prediction(BaseModel):
    label: str
    confidence: float


class PredictionResponse(BaseModel):
    predictions: List[Prediction]
    top_prediction_bounding_boxes: List[Tuple[int, int, int, int]]
@app.get("/predict")  # route path assumed; takes the image URL as a query parameter
def predict(image_url: str = Query(..., description="URL of chest X-ray image")):
    try:
        img_tensor = preprocess_image_from_url(image_url)
        preds, bboxes = get_predictions_and_bounding_box(img_tensor)
        prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
        return PredictionResponse(
            predictions=prediction_list,
            top_prediction_bounding_boxes=bboxes
        )
    except Exception as e:
        return {"error": str(e)}
class URLRequest(BaseModel):
    url: str


@app.post("/predict_url")  # route path assumed; takes the image URL in a JSON body
def predict_from_url(body: URLRequest):
    try:
        img_tensor = preprocess_image_from_url(body.url)
        preds, bboxes = get_predictions_and_bounding_box(img_tensor)
        prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
        return PredictionResponse(
            predictions=prediction_list,
            top_prediction_bounding_boxes=bboxes
        )
    except Exception as e:
        return {"error": str(e)}
# uvicorn app:app --reload
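# Example requests (sketch; assumes the route paths above and a local server on port 8000):
#   curl "http://localhost:8000/predict?image_url=https://example.com/chest_xray.jpg"
#   curl -X POST "http://localhost:8000/predict_url" \
#        -H "Content-Type: application/json" \
#        -d '{"url": "https://example.com/chest_xray.jpg"}'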