dreibh commited on
Commit
ef44cbd
·
verified ·
1 Parent(s): 9370c28

Further improvements.

Browse files
Files changed (1) hide show
  1. app.py +96 -19
app.py CHANGED
@@ -40,6 +40,8 @@ import gradio
40
  import io
41
  import matplotlib.pyplot as plt
42
  import matplotlib.ticker
 
 
43
  import random
44
  import sys
45
  import tempfile
@@ -64,10 +66,24 @@ class Session:
64
 
65
  # ###### Constructor #####################################################
66
  def __init__(self):
67
- self.Lock = threading.Lock()
68
- self.Counter = 0
69
- self.Results = None
70
- self.Selected = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # ###### Increment counter ###############################################
73
  def increment(self):
@@ -133,6 +149,7 @@ def predict(numberOfECGs: int = 1,
133
  outputFormat = deepfakeecg.OUTPUT_TENSOR,
134
  showProgress = False,
135
  runOnDevice = runOnDevice)
 
136
 
137
  # ====== Create a list of image/label tuples for gradio.Gallery ==========
138
  plotList = []
@@ -182,10 +199,10 @@ def select(event: gradio.SelectData,
182
  log(f'Session "{request.session_hash}": Selected ECG #{Sessions[request.session_hash].Selected + 1}')
183
 
184
 
185
- # ###### Produce CSV file from Tensor #######################################
186
- def dataToCSV(data, outputFileName, ecgType = deepfakeecg.DATA_ECG12) -> sys.path:
187
 
188
- data = generatedECG.detach().cpu().numpy()
189
 
190
  if ecgType == deepfakeecg.DATA_ECG8:
191
  header = 'Timestamp,LeadI,LeadII,V1,V2,V3,V4,V5,V6'
@@ -195,20 +212,63 @@ def dataToCSV(data, outputFileName, ecgType = deepfakeecg.DATA_ECG12) -> sys.pat
195
  raise Exception('Invalid ECG type!')
196
 
197
  numpy.savetxt(outputFileName, data,
198
- header = header,
199
- comments = '',
200
- delimiter = ',',
201
- fmt = '%i')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
 
204
  # ###### Download CSV #######################################################
205
- def downloadCSV(request: gradio.Request):
206
- log(f'Session "{request.session_hash}": Download CSV file')
 
 
 
 
 
 
 
 
 
207
 
208
 
209
  # ###### Download PDF #######################################################
210
  def downloadPDF(request: gradio.Request):
211
- log(f'Session "{request.session_hash}": Download PDF file')
 
 
 
 
 
 
 
 
212
 
213
 
214
  # ###### Analyze the selected ECG ###########################################
@@ -318,13 +378,17 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
318
  buttonGenerate = gradio.Button("Generate ECGs!")
319
  buttonAnalyze = gradio.Button("Analyze this ECG!")
320
  with gradio.Row():
321
- buttonCSV = gradio.Button("Download CSV")
322
- buttonPDF = gradio.Button("Download PDF")
 
 
323
  gradio.Markdown('## Output')
324
  with gradio.Row():
325
- outputGallery = gradio.Gallery(label = 'output', columns = [ 1 ], height = 'auto',
 
 
326
  show_label = True,
327
- preview = True)
328
  outputGallery.select(select)
329
  gradio.Markdown('## Analysis')
330
 
@@ -341,8 +405,16 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
341
  buttonAnalyze.click(analyze)
342
 
343
  # ====== Add click event handling for download buttons ===================
 
 
344
  buttonCSV.click(downloadCSV)
 
 
 
345
  buttonPDF.click(downloadPDF)
 
 
 
346
 
347
  # ====== Run on startup ==================================================
348
  gui.load(predict,
@@ -356,9 +428,14 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
356
 
357
  # ====== Run the GUI ========================================================
358
  if __name__ == "__main__":
359
- TempDirectory = tempfile.TemporaryDirectory('DeepFakeECGPlus')
 
360
  log(f'Prepared temporary directory {TempDirectory.name}')
 
 
361
  gui.launch(allowed_paths = [ TempDirectory.name ])
 
 
362
  log(f'Cleaning up temporary directory {TempDirectory.name}')
363
  TempDirectory.cleanup()
364
  log('Done!')
 
40
  import io
41
  import matplotlib.pyplot as plt
42
  import matplotlib.ticker
43
+ import numpy
44
+ import pathlib
45
  import random
46
  import sys
47
  import tempfile
 
66
 
67
  # ###### Constructor #####################################################
68
  def __init__(self):
69
+ self.Lock = threading.Lock()
70
+ self.Counter = 0
71
+ self.Selected = 0
72
+ self.Results = None
73
+ self.Type = None
74
+ self.TempDirectory = tempfile.TemporaryDirectory(dir = TempDirectory.name)
75
+ log(f'Prepared temporary directory {self.TempDirectory.name}')
76
+
77
+ # ###### Destructor ######################################################
78
+ def __del__(self):
79
+ log(f'Cleaning up temporary directory {self.TempDirectory.name}')
80
+ self.TempDirectory.cleanup()
81
+
82
+ # ###### Increment counter ###############################################
83
+ def increment(self):
84
+ with self.lock:
85
+ self.counter += 1
86
+ return self.counter
87
 
88
  # ###### Increment counter ###############################################
89
  def increment(self):
 
149
  outputFormat = deepfakeecg.OUTPUT_TENSOR,
150
  showProgress = False,
151
  runOnDevice = runOnDevice)
152
+ Sessions[request.session_hash].Type = ecgType
153
 
154
  # ====== Create a list of image/label tuples for gradio.Gallery ==========
155
  plotList = []
 
199
  log(f'Session "{request.session_hash}": Selected ECG #{Sessions[request.session_hash].Selected + 1}')
200
 
201
 
202
+ # ###### Produce ECG CSV file from Tensor ###################################
203
+ def dataToCSV(ecgResult, ecgType, outputFileName):
204
 
205
+ data = ecgResult.detach().cpu().numpy()
206
 
207
  if ecgType == deepfakeecg.DATA_ECG8:
208
  header = 'Timestamp,LeadI,LeadII,V1,V2,V3,V4,V5,V6'
 
212
  raise Exception('Invalid ECG type!')
213
 
214
  numpy.savetxt(outputFileName, data,
215
+ header = header,
216
+ comments = '',
217
+ delimiter = ',',
218
+ fmt = '%i')
219
+
220
+
221
+ # ###### Produce ECG PDF file from Tensor ###################################
222
+ def dataToPDF(ecgResult, ecgType, outputFileName):
223
+
224
+ data = ecgResult.detach().cpu().numpy()
225
+ outputLeads = deepfakeecg.ECG_LEADS
226
+
227
+ matplotlib.pyplot.figure(figsize=(15, 3))
228
+ for outputLead in outputLeads:
229
+ try:
230
+ outputLeadIndex = deepfakeecg.ECG_LEADS[outputLead][0]
231
+ outputLeadLabel = deepfakeecg.ECG_LEADS[outputLead][1]
232
+ outputLeadType = deepfakeecg.ECG_LEADS[outputLead][2]
233
+ except:
234
+ raise Exception('Invalid lead ' + outputLead + '!')
235
+ if outputLeadType > ecgType:
236
+ raise Exception('Invalid lead ' + outputLead + ' for this ECG type!')
237
+ matplotlib.pyplot.plot(data[:, outputLeadIndex], label = outputLeadLabel)
238
+ matplotlib.pyplot.legend()
239
+ matplotlib.pyplot.title('Generated ECG — ID ' + str(i))
240
+ matplotlib.pyplot.xlabel('Time [s]')
241
+ matplotlib.pyplot.ylabel('Amplitude [μV]')
242
+ matplotlib.pyplot.grid(True)
243
+ matplotlib.pyplot.ylim(-1000, +1000)
244
+ matplotlib.pyplot.savefig(outputFileName)
245
 
246
 
247
  # ###### Download CSV #######################################################
248
+ def downloadCSV(request: gradio.Request) -> pathlib.Path:
249
+
250
+ ecgResult = Sessions[request.session_hash].Results[Sessions[request.session_hash].Selected]
251
+ ecgType = Sessions[request.session_hash].Type
252
+ fileName = pathlib.Path(Sessions[request.session_hash].TempDirectory.name) / \
253
+ ('ECG-' + str(Sessions[request.session_hash].Selected + 1) + '.csv')
254
+ dataToCSV(ecgResult, ecgType, fileName)
255
+
256
+ log(f'Session "{request.session_hash}": Download CSV file {fileName}')
257
+ return fileName
258
+
259
 
260
 
261
  # ###### Download PDF #######################################################
262
  def downloadPDF(request: gradio.Request):
263
+
264
+ ecgResult = Sessions[request.session_hash].Results[Sessions[request.session_hash].Selected]
265
+ ecgType = Sessions[request.session_hash].Type
266
+ fileName = pathlib.Path(Sessions[request.session_hash].TempDirectory.name) / \
267
+ ('ECG-' + str(Sessions[request.session_hash].Selected + 1) + '.pdf')
268
+ dataToPDF(ecgResult, ecgType, fileName)
269
+
270
+ log(f'Session "{request.session_hash}": Download PDF file {fileName}')
271
+ return fileName
272
 
273
 
274
  # ###### Analyze the selected ECG ###########################################
 
378
  buttonGenerate = gradio.Button("Generate ECGs!")
379
  buttonAnalyze = gradio.Button("Analyze this ECG!")
380
  with gradio.Row():
381
+ buttonCSV = gradio.DownloadButton("Download CSV")
382
+ buttonCSV_hidden = gradio.DownloadButton(visible=False, elem_id="download_csv_hidden")
383
+ buttonPDF = gradio.DownloadButton("Download PDF")
384
+ buttonPDF_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdf_hidden")
385
  gradio.Markdown('## Output')
386
  with gradio.Row():
387
+ outputGallery = gradio.Gallery(label = 'output',
388
+ columns = [ 1 ],
389
+ height = 'auto',
390
  show_label = True,
391
+ preview = True)
392
  outputGallery.select(select)
393
  gradio.Markdown('## Analysis')
394
 
 
405
  buttonAnalyze.click(analyze)
406
 
407
  # ====== Add click event handling for download buttons ===================
408
+ # Using hidden button and JavaScript, to generate download file on-the-fly:
409
+ # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
410
  buttonCSV.click(downloadCSV)
411
+ buttonCSV.click(fn = downloadCSV, inputs = None, outputs = [ buttonCSV_hidden ]).then(
412
+ fn = None, inputs = None, outputs = None,
413
+ js = "() => document.querySelector('#download_csv_hidden').click()")
414
  buttonPDF.click(downloadPDF)
415
+ buttonPDF.click(fn = downloadPDF, inputs = None, outputs = [ buttonPDF_hidden ]).then(
416
+ fn = None, inputs = None, outputs = None,
417
+ js = "() => document.querySelector('#download_pdf_hidden').click()")
418
 
419
  # ====== Run on startup ==================================================
420
  gui.load(predict,
 
428
 
429
  # ====== Run the GUI ========================================================
430
  if __name__ == "__main__":
431
+ # ------ Prepare temporary directory -------------------------------------
432
+ TempDirectory = tempfile.TemporaryDirectory(prefix = 'DeepFakeECGPlus-')
433
  log(f'Prepared temporary directory {TempDirectory.name}')
434
+
435
+ # ------ Run the GUI, with downloads from temporary directory allowed ----
436
  gui.launch(allowed_paths = [ TempDirectory.name ])
437
+
438
+ # ------ Clean up --------------------------------------------------------
439
  log(f'Cleaning up temporary directory {TempDirectory.name}')
440
  TempDirectory.cleanup()
441
  log('Done!')