Some progress.
Browse files
app.py
CHANGED
@@ -35,15 +35,52 @@
|
|
35 |
|
36 |
import deepfakeecg
|
37 |
import ecg_plot
|
38 |
-
import io
|
39 |
import gradio
|
|
|
40 |
import matplotlib.pyplot as plt
|
41 |
import matplotlib.ticker
|
|
|
|
|
|
|
|
|
42 |
import torch
|
43 |
import typing
|
44 |
import PIL
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# ###### Generate ECGs ######################################################
|
48 |
def predict(numberOfECGs = 1,
|
49 |
# ecgLengthInSeconds = 10,
|
@@ -68,18 +105,19 @@ def predict(numberOfECGs = 1,
|
|
68 |
# print(matplotlib.ticker.Locator.MAXTICKS)
|
69 |
|
70 |
# ====== Generate the ECGs ===============================================
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
78 |
|
79 |
# ====== Create a list of image/label tuples for gradio.Gallery ==========
|
80 |
plotList = []
|
81 |
ecgNumber = 1
|
82 |
-
for result in
|
83 |
|
84 |
# ====== Plot ECG =====================================================
|
85 |
result = result.t().detach().cpu().numpy()
|
@@ -114,6 +152,59 @@ def predict(numberOfECGs = 1,
|
|
114 |
return plotList
|
115 |
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
# ###### Main program #######################################################
|
119 |
|
@@ -180,8 +271,15 @@ img.logo-image {
|
|
180 |
}
|
181 |
"""
|
182 |
|
|
|
183 |
# ====== Create GUI =========================================================
|
184 |
with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
big_block = gradio.HTML("""
|
186 |
<div class="header">
|
187 |
<div class="logo-left">
|
@@ -199,12 +297,19 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
|
|
199 |
# sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
|
200 |
dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
|
201 |
dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
203 |
gradio.Markdown('## Output')
|
204 |
with gradio.Row():
|
205 |
outputGallery = gradio.Gallery(label = 'output', columns = [ 1 ], height = 'auto',
|
206 |
show_label = True,
|
207 |
preview = True)
|
|
|
|
|
208 |
|
209 |
# ====== Add click event handling for "Generate" button ==================
|
210 |
buttonGenerate.click(predict,
|
@@ -215,6 +320,13 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
|
|
215 |
outputs = [ outputGallery ]
|
216 |
)
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
# ====== Run on startup ==================================================
|
219 |
gui.load(predict,
|
220 |
inputs = [ sliderNumberOfECGs,
|
@@ -226,4 +338,6 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
|
|
226 |
|
227 |
# ====== Run the GUI ========================================================
|
228 |
if __name__ == "__main__":
|
229 |
-
|
|
|
|
|
|
35 |
|
36 |
import deepfakeecg
|
37 |
import ecg_plot
|
|
|
38 |
import gradio
|
39 |
+
import io
|
40 |
import matplotlib.pyplot as plt
|
41 |
import matplotlib.ticker
|
42 |
+
import random
|
43 |
+
import sys
|
44 |
+
import tempfile
|
45 |
+
import threading
|
46 |
import torch
|
47 |
import typing
|
48 |
import PIL
|
49 |
|
50 |
|
51 |
+
TempDirectory = None
|
52 |
+
LastResults = None
|
53 |
+
SelectedECGIndex = 0
|
54 |
+
|
55 |
+
|
56 |
+
# ###### Make a unique session ID ###########################################
|
57 |
+
SessionCounterLock = threading.Lock()
|
58 |
+
SessionCounter = 0
|
59 |
+
def generateSessionID():
|
60 |
+
global SessionCounterLock
|
61 |
+
global SessionCounter
|
62 |
+
|
63 |
+
SessionCounterLock.acquire()
|
64 |
+
SessionCounter = SessionCounter + 1
|
65 |
+
sessionID = SessionCounter
|
66 |
+
SessionCounterLock.release()
|
67 |
+
print(f'SessionID={sessionID}')
|
68 |
+
|
69 |
+
return sessionID
|
70 |
+
|
71 |
+
|
72 |
+
# ###### Get last results ###################################################
|
73 |
+
def getLastResults() -> list:
|
74 |
+
return LastResults
|
75 |
+
|
76 |
+
|
77 |
+
# ###### Get last result ####################################################
|
78 |
+
def getLastResult(index: int) -> torch.Tensor:
|
79 |
+
if LastResults != None:
|
80 |
+
return LastResults[index]
|
81 |
+
return None
|
82 |
+
|
83 |
+
|
84 |
# ###### Generate ECGs ######################################################
|
85 |
def predict(numberOfECGs = 1,
|
86 |
# ecgLengthInSeconds = 10,
|
|
|
105 |
# print(matplotlib.ticker.Locator.MAXTICKS)
|
106 |
|
107 |
# ====== Generate the ECGs ===============================================
|
108 |
+
global LastResults
|
109 |
+
LastResults = deepfakeecg.generateDeepfakeECGs(numberOfECGs,
|
110 |
+
ecgType = ecgType,
|
111 |
+
ecgLengthInSeconds = ecgLengthInSeconds,
|
112 |
+
ecgScaleFactor = 6,
|
113 |
+
outputFormat = deepfakeecg.OUTPUT_TENSOR,
|
114 |
+
showProgress = False,
|
115 |
+
runOnDevice = runOnDevice)
|
116 |
|
117 |
# ====== Create a list of image/label tuples for gradio.Gallery ==========
|
118 |
plotList = []
|
119 |
ecgNumber = 1
|
120 |
+
for result in LastResults:
|
121 |
|
122 |
# ====== Plot ECG =====================================================
|
123 |
result = result.t().detach().cpu().numpy()
|
|
|
152 |
return plotList
|
153 |
|
154 |
|
155 |
+
# ###### Select ECG in the gallery ##########################################
|
156 |
+
def select(event: gradio.SelectData):
|
157 |
+
# Get selection index from Gallery select() event:
|
158 |
+
# https://github.com/gradio-app/gradio/issues/1976#issuecomment-1726018500
|
159 |
+
|
160 |
+
global SelectedECGIndex
|
161 |
+
SelectedECGIndex = event.index
|
162 |
+
print(f'Selected #{SelectedECGIndex}!')
|
163 |
+
|
164 |
+
# return event.value
|
165 |
+
|
166 |
+
|
167 |
+
# ###### Produce CSV file from Tensor #######################################
|
168 |
+
def dataToCSV(data, outputFileName, ecgType = deepfakeecg.DATA_ECG12) -> sys.path:
|
169 |
+
|
170 |
+
data = generatedECG.detach().cpu().numpy()
|
171 |
+
|
172 |
+
if ecgType == deepfakeecg.DATA_ECG8:
|
173 |
+
header = 'Timestamp,LeadI,LeadII,V1,V2,V3,V4,V5,V6'
|
174 |
+
elif ecgType == deepfakeecg.DATA_ECG12:
|
175 |
+
header = 'Timestamp,LeadI,LeadII,V1,V2,V3,V4,V5,V6,LeadIII,aVL,aVR,aVF'
|
176 |
+
else:
|
177 |
+
raise Exception('Invalid ECG type!')
|
178 |
+
|
179 |
+
numpy.savetxt(outputFileName, data,
|
180 |
+
header = header,
|
181 |
+
comments = '',
|
182 |
+
delimiter = ',',
|
183 |
+
fmt = '%i')
|
184 |
+
|
185 |
+
|
186 |
+
# ###### Download CSV #######################################################
|
187 |
+
def downloadCSV(sessionID) -> None:
|
188 |
+
print(f'CSV #{SelectedECGIndex}!')
|
189 |
+
print(f"sessionID={sessionID}")
|
190 |
+
|
191 |
+
# ###### Download PDF #######################################################
|
192 |
+
def downloadPDF(sessionID) -> None:
|
193 |
+
print(f'PDF #{SelectedECGIndex}!')
|
194 |
+
print(f"sessionID={sessionID}")
|
195 |
+
|
196 |
+
|
197 |
+
# ###### Analyze the selected ECG ###########################################
|
198 |
+
def analyze() -> None:
|
199 |
+
|
200 |
+
print(f'Analyze #{SelectedECGIndex}!')
|
201 |
+
|
202 |
+
data = getLastResult(SelectedECGIndex)
|
203 |
+
print(data)
|
204 |
+
|
205 |
+
return None
|
206 |
+
|
207 |
+
|
208 |
|
209 |
# ###### Main program #######################################################
|
210 |
|
|
|
271 |
}
|
272 |
"""
|
273 |
|
274 |
+
|
275 |
# ====== Create GUI =========================================================
|
276 |
with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:
|
277 |
+
|
278 |
+
# ====== Unique session ID for this instance =============================
|
279 |
+
sessionID = gradio.State(0)
|
280 |
+
gui.load(generateSessionID, outputs = [ sessionID ])
|
281 |
+
|
282 |
+
# ====== Header ==========================================================
|
283 |
big_block = gradio.HTML("""
|
284 |
<div class="header">
|
285 |
<div class="logo-left">
|
|
|
297 |
# sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
|
298 |
dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
|
299 |
dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
|
300 |
+
with gradio.Column():
|
301 |
+
buttonGenerate = gradio.Button("Generate ECGs!")
|
302 |
+
buttonAnalyze = gradio.Button("Analyze this ECG!")
|
303 |
+
with gradio.Row():
|
304 |
+
buttonCSV = gradio.Button("Download CSV")
|
305 |
+
buttonPDF = gradio.Button("Download PDF")
|
306 |
gradio.Markdown('## Output')
|
307 |
with gradio.Row():
|
308 |
outputGallery = gradio.Gallery(label = 'output', columns = [ 1 ], height = 'auto',
|
309 |
show_label = True,
|
310 |
preview = True)
|
311 |
+
outputGallery.select(select)
|
312 |
+
gradio.Markdown('## Analysis')
|
313 |
|
314 |
# ====== Add click event handling for "Generate" button ==================
|
315 |
buttonGenerate.click(predict,
|
|
|
320 |
outputs = [ outputGallery ]
|
321 |
)
|
322 |
|
323 |
+
# ====== Add click event handling for "Analyze" button ===================
|
324 |
+
buttonAnalyze.click(analyze)
|
325 |
+
|
326 |
+
# ====== Add click event handling for download buttons ===================
|
327 |
+
buttonCSV.click(downloadCSV, inputs = [ sessionID ])
|
328 |
+
buttonPDF.click(downloadPDF, inputs = [ sessionID ])
|
329 |
+
|
330 |
# ====== Run on startup ==================================================
|
331 |
gui.load(predict,
|
332 |
inputs = [ sliderNumberOfECGs,
|
|
|
338 |
|
339 |
# ====== Run the GUI ========================================================
|
340 |
if __name__ == "__main__":
|
341 |
+
TempDirectory = tempfile.TemporaryDirectory('DeepFakeECGPlus')
|
342 |
+
gui.launch(allowed_paths = [ TempDirectory ])
|
343 |
+
TempDirectory.cleanup()
|