import os

# Point the torchxrayvision and matplotlib caches at writable locations.
# Set these before the libraries are imported and before model weights are downloaded.
os.environ['TORCHXRAYVISION_CACHE'] = './torchxrayvision_cache'
os.environ['MPLCONFIGDIR'] = './.config/matplotlib'  # matplotlib font-cache location

from fastapi import FastAPI, Query, HTTPException
from pydantic import BaseModel
from typing import List, Tuple
import torch
import torchxrayvision as xrv
import torchvision
import skimage.io
import numpy as np
import requests
import cv2
from io import BytesIO
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
app = FastAPI()
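# DenseNet-121 trained across multiple chest X-ray datasets ("all" weights), loaded once at startup.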
model = xrv.models.DenseNet(weights="densenet121-res224-all")
model.eval()
def preprocess_image_from_url(image_url: str) -> torch.Tensor:
    # Download the image and fail fast on HTTP errors.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    img = skimage.io.imread(BytesIO(response.content))
    # Rescale pixel values to the [-1024, 1024] range expected by torchxrayvision models.
    img = xrv.datasets.normalize(img, 255)
    # Collapse RGB to a single channel and add a channel dimension: (1, H, W).
    if img.ndim == 3:
        img = img.mean(2)[None, ...]
    elif img.ndim == 2:
        img = img[None, ...]
    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224)
    ])
    img = transform(img)
    img_tensor = torch.from_numpy(img)
    return img_tensor
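# Example (hypothetical URL): preprocess_image_from_url("https://example.com/xray.png")
# returns a float32 tensor of shape (1, 224, 224), scaled to roughly [-1024, 1024],
# the value range torchxrayvision models expect.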
def get_predictions_and_bounding_box(img_tensor: torch.Tensor):
    # Forward pass for the pathology scores (no gradients needed here).
    with torch.no_grad():
        output = model(img_tensor[None, ...])[0]
    predictions = dict(zip(model.pathologies, output.numpy()))
    sorted_preds = sorted(predictions.items(), key=lambda x: -x[1])
    top_pred_label, top_conf = sorted_preds[0]
    top_pred_index = list(model.pathologies).index(top_pred_label)

    # Grad-CAM over the last feature block for the top-scoring pathology
    # (run outside torch.no_grad, since it needs gradients).
    target_layer = model.features[-1]
    cam = GradCAM(model=model, target_layers=[target_layer])
    grayscale_cam = cam(input_tensor=img_tensor[None, ...],
                        targets=[ClassifierOutputTarget(top_pred_index)])[0, :]

    # RGB copy of the input, kept for an optional heatmap overlay (e.g. show_cam_on_image).
    input_img = img_tensor.numpy()[0]
    input_img_norm = (input_img - input_img.min()) / (input_img.max() - input_img.min())
    input_img_rgb = cv2.cvtColor((input_img_norm * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)

    # Threshold the heatmap and extract bounding boxes around the activated regions.
    cam_resized = cv2.resize(grayscale_cam, (224, 224))
    cam_uint8 = (cam_resized * 255).astype(np.uint8)
    _, thresh = cv2.threshold(cam_uint8, 100, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    bounding_boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        # Boxes are (x1, y1, x2, y2) in the 224x224 coordinates of the preprocessed image.
        bounding_boxes.append((int(x), int(y), int(x + w), int(y + h)))
    return sorted_preds, bounding_boxes
class Prediction(BaseModel):
label: str
confidence: float
class PredictionResponse(BaseModel):
predictions: List[Prediction]
top_prediction_bounding_boxes: List[Tuple[int, int, int, int]]
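# Example response shape (illustrative values; box tuples serialize as JSON arrays):
# {
#   "predictions": [{"label": "Cardiomegaly", "confidence": 0.62}, ...],
#   "top_prediction_bounding_boxes": [[40, 55, 170, 200]]
# }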
@app.get("/predict", response_model=PredictionResponse)
def predict(image_url: str = Query(..., description="URL of chest X-ray image")):
try:
img_tensor = preprocess_image_from_url(image_url)
preds, bboxes = get_predictions_and_bounding_box(img_tensor)
prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
return PredictionResponse(
predictions=prediction_list,
top_prediction_bounding_boxes=bboxes
)
    except Exception as e:
        # A bare {"error": ...} dict would fail response_model validation; raise an HTTP error instead.
        raise HTTPException(status_code=400, detail=str(e))
class URLRequest(BaseModel):
url: str
@app.post("/predict", response_model=PredictionResponse)
def predict_from_url(body: URLRequest):
try:
img_tensor = preprocess_image_from_url(body.url)
preds, bboxes = get_predictions_and_bounding_box(img_tensor)
prediction_list = [Prediction(label=label, confidence=float(conf)) for label, conf in preds]
return PredictionResponse(
predictions=prediction_list,
top_prediction_bounding_boxes=bboxes
)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
# uvicorn app:app --reload
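# Example requests (hypothetical image URL, for illustration only):
#   curl "http://localhost:8000/predict?image_url=https://example.com/chest_xray.png"
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"url": "https://example.com/chest_xray.png"}'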