aliabd HF Staff commited on
Commit
2e2c116
·
1 Parent(s): 344630a

Delete app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +0 -91
app.py DELETED
@@ -1,91 +0,0 @@
1
- import os
2
- from os.path import splitext
3
- import numpy as np
4
- import sys
5
- import matplotlib.pyplot as plt
6
- import torch
7
- import torchvision
8
- import wget
9
-
10
-
11
- destination_folder = "output"
12
- destination_for_weights = "weights"
13
-
14
- if os.path.exists(destination_for_weights):
15
- print("The weights are at", destination_for_weights)
16
- else:
17
- print("Creating folder at ", destination_for_weights, " to store weights")
18
- os.mkdir(destination_for_weights)
19
-
20
- segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'
21
-
22
- if not os.path.exists(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))):
23
- print("Downloading Segmentation Weights, ", segmentationWeightsURL," to ",os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
24
- filename = wget.download(segmentationWeightsURL, out = destination_for_weights)
25
- else:
26
- print("Segmentation Weights already present")
27
-
28
- torch.cuda.empty_cache()
29
-
30
- def collate_fn(x):
31
- x, f = zip(*x)
32
- i = list(map(lambda t: t.shape[1], x))
33
- x = torch.as_tensor(np.swapaxes(np.concatenate(x, 1), 0, 1))
34
- return x, f, i
35
-
36
- model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
37
- model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)
38
-
39
- print("loading weights from ", os.path.join(destination_for_weights, "deeplabv3_resnet50_random"))
40
-
41
- if torch.cuda.is_available():
42
- print("cuda is available, original weights")
43
- device = torch.device("cuda")
44
- model = torch.nn.DataParallel(model)
45
- model.to(device)
46
- checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
47
- model.load_state_dict(checkpoint['state_dict'])
48
- else:
49
- print("cuda is not available, cpu weights")
50
- device = torch.device("cpu")
51
- checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)), map_location = "cpu")
52
- state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}
53
- model.load_state_dict(state_dict_cpu)
54
-
55
- model.eval()
56
-
57
- def segment(input):
58
- inp = input
59
- x = inp.transpose([2, 0, 1]) # channels-first
60
- x = np.expand_dims(x, axis=0) # adding a batch dimension
61
-
62
- mean = x.mean(axis=(0, 2, 3))
63
- std = x.std(axis=(0, 2, 3))
64
- x = x - mean.reshape(1, 3, 1, 1)
65
- x = x / std.reshape(1, 3, 1, 1)
66
-
67
- with torch.no_grad():
68
- x = torch.from_numpy(x).type('torch.FloatTensor').to(device)
69
- output = model(x)
70
-
71
- y = output['out'].numpy()
72
- y = y.squeeze()
73
-
74
- out = y>0
75
-
76
- mask = inp.copy()
77
- mask[out] = np.array([0, 0, 255])
78
-
79
- return mask
80
-
81
- import gradio as gr
82
-
83
- i = gr.inputs.Image(shape=(112, 112), label="Echocardiogram")
84
- o = gr.outputs.Image(label="Segmentation Mask")
85
-
86
- examples = [["img1.jpg"], ["img2.jpg"]]
87
- title = None #"Left Ventricle Segmentation"
88
- description = "This semantic segmentation model identifies the left ventricle in echocardiogram images."
89
- # videos. Accurate evaluation of the motion and size of the left ventricle is crucial for the assessment of cardiac function and ejection fraction. In this interface, the user inputs apical-4-chamber images from echocardiography videos and the model will output a prediction of the localization of the left ventricle in blue. This model was trained on the publicly released EchoNet-Dynamic dataset of 10k echocardiogram videos with 20k expert annotations of the left ventricle and published as part of ‘Video-based AI for beat-to-beat assessment of cardiac function’ by Ouyang et al. in Nature, 2020."
90
- thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
91
- gr.Interface(segment, i, o, examples=examples, allow_flagging=False, analytics_enabled=False, thumbnail=thumbnail, cache_examples=False).launch()