dreibh commited on
Commit
429321a
·
verified ·
1 Parent(s): dfaac8a

Some progress.

Browse files
Files changed (1) hide show
  1. app.py +125 -11
app.py CHANGED
@@ -35,15 +35,52 @@
35
 
36
  import deepfakeecg
37
  import ecg_plot
38
- import io
39
  import gradio
 
40
  import matplotlib.pyplot as plt
41
  import matplotlib.ticker
 
 
 
 
42
  import torch
43
  import typing
44
  import PIL
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # ###### Generate ECGs ######################################################
48
  def predict(numberOfECGs = 1,
49
  # ecgLengthInSeconds = 10,
@@ -68,18 +105,19 @@ def predict(numberOfECGs = 1,
68
  # print(matplotlib.ticker.Locator.MAXTICKS)
69
 
70
  # ====== Generate the ECGs ===============================================
71
- results = deepfakeecg.generateDeepfakeECGs(numberOfECGs,
72
- ecgType = ecgType,
73
- ecgLengthInSeconds = ecgLengthInSeconds,
74
- ecgScaleFactor = 6,
75
- outputFormat = deepfakeecg.OUTPUT_TENSOR,
76
- showProgress = False,
77
- runOnDevice = runOnDevice)
 
78
 
79
  # ====== Create a list of image/label tuples for gradio.Gallery ==========
80
  plotList = []
81
  ecgNumber = 1
82
- for result in results:
83
 
84
  # ====== Plot ECG =====================================================
85
  result = result.t().detach().cpu().numpy()
@@ -114,6 +152,59 @@ def predict(numberOfECGs = 1,
114
  return plotList
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  # ###### Main program #######################################################
119
 
@@ -180,8 +271,15 @@ img.logo-image {
180
  }
181
  """
182
 
 
183
  # ====== Create GUI =========================================================
184
  with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:
 
 
 
 
 
 
185
  big_block = gradio.HTML("""
186
  <div class="header">
187
  <div class="logo-left">
@@ -199,12 +297,19 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
199
  # sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
200
  dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
201
  dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
202
- buttonGenerate = gradio.Button("Generate")
 
 
 
 
 
203
  gradio.Markdown('## Output')
204
  with gradio.Row():
205
  outputGallery = gradio.Gallery(label = 'output', columns = [ 1 ], height = 'auto',
206
  show_label = True,
207
  preview = True)
 
 
208
 
209
  # ====== Add click event handling for "Generate" button ==================
210
  buttonGenerate.click(predict,
@@ -215,6 +320,13 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
215
  outputs = [ outputGallery ]
216
  )
217
 
 
 
 
 
 
 
 
218
  # ====== Run on startup ==================================================
219
  gui.load(predict,
220
  inputs = [ sliderNumberOfECGs,
@@ -226,4 +338,6 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
226
 
227
  # ====== Run the GUI ========================================================
228
  if __name__ == "__main__":
229
- gui.launch()
 
 
 
35
 
36
  import deepfakeecg
37
  import ecg_plot
 
38
  import gradio
39
+ import io
40
  import matplotlib.pyplot as plt
41
  import matplotlib.ticker
42
+ import random
43
+ import sys
44
+ import tempfile
45
+ import threading
46
  import torch
47
  import typing
48
  import PIL
49
 
50
 
51
+ TempDirectory = None
52
+ LastResults = None
53
+ SelectedECGIndex = 0
54
+
55
+
56
+ # ###### Make a unique session ID ###########################################
57
+ SessionCounterLock = threading.Lock()
58
+ SessionCounter = 0
59
+ def generateSessionID():
60
+ global SessionCounterLock
61
+ global SessionCounter
62
+
63
+ SessionCounterLock.acquire()
64
+ SessionCounter = SessionCounter + 1
65
+ sessionID = SessionCounter
66
+ SessionCounterLock.release()
67
+ print(f'SessionID={sessionID}')
68
+
69
+ return sessionID
70
+
71
+
72
+ # ###### Get last results ###################################################
73
+ def getLastResults() -> list:
74
+ return LastResults
75
+
76
+
77
+ # ###### Get last result ####################################################
78
+ def getLastResult(index: int) -> torch.Tensor:
79
+ if LastResults != None:
80
+ return LastResults[index]
81
+ return None
82
+
83
+
84
  # ###### Generate ECGs ######################################################
85
  def predict(numberOfECGs = 1,
86
  # ecgLengthInSeconds = 10,
 
105
  # print(matplotlib.ticker.Locator.MAXTICKS)
106
 
107
  # ====== Generate the ECGs ===============================================
108
+ global LastResults
109
+ LastResults = deepfakeecg.generateDeepfakeECGs(numberOfECGs,
110
+ ecgType = ecgType,
111
+ ecgLengthInSeconds = ecgLengthInSeconds,
112
+ ecgScaleFactor = 6,
113
+ outputFormat = deepfakeecg.OUTPUT_TENSOR,
114
+ showProgress = False,
115
+ runOnDevice = runOnDevice)
116
 
117
  # ====== Create a list of image/label tuples for gradio.Gallery ==========
118
  plotList = []
119
  ecgNumber = 1
120
+ for result in LastResults:
121
 
122
  # ====== Plot ECG =====================================================
123
  result = result.t().detach().cpu().numpy()
 
152
  return plotList
153
 
154
 
155
+ # ###### Select ECG in the gallery ##########################################
156
+ def select(event: gradio.SelectData):
157
+ # Get selection index from Gallery select() event:
158
+ # https://github.com/gradio-app/gradio/issues/1976#issuecomment-1726018500
159
+
160
+ global SelectedECGIndex
161
+ SelectedECGIndex = event.index
162
+ print(f'Selected #{SelectedECGIndex}!')
163
+
164
+ # return event.value
165
+
166
+
167
+ # ###### Produce CSV file from Tensor #######################################
168
+ def dataToCSV(data, outputFileName, ecgType = deepfakeecg.DATA_ECG12) -> sys.path:
169
+
170
+ data = generatedECG.detach().cpu().numpy()
171
+
172
+ if ecgType == deepfakeecg.DATA_ECG8:
173
+ header = 'Timestamp,LeadI,LeadII,V1,V2,V3,V4,V5,V6'
174
+ elif ecgType == deepfakeecg.DATA_ECG12:
175
+ header = 'Timestamp,LeadI,LeadII,V1,V2,V3,V4,V5,V6,LeadIII,aVL,aVR,aVF'
176
+ else:
177
+ raise Exception('Invalid ECG type!')
178
+
179
+ numpy.savetxt(outputFileName, data,
180
+ header = header,
181
+ comments = '',
182
+ delimiter = ',',
183
+ fmt = '%i')
184
+
185
+
186
+ # ###### Download CSV #######################################################
187
+ def downloadCSV(sessionID) -> None:
188
+ print(f'CSV #{SelectedECGIndex}!')
189
+ print(f"sessionID={sessionID}")
190
+
191
+ # ###### Download PDF #######################################################
192
+ def downloadPDF(sessionID) -> None:
193
+ print(f'PDF #{SelectedECGIndex}!')
194
+ print(f"sessionID={sessionID}")
195
+
196
+
197
+ # ###### Analyze the selected ECG ###########################################
198
+ def analyze() -> None:
199
+
200
+ print(f'Analyze #{SelectedECGIndex}!')
201
+
202
+ data = getLastResult(SelectedECGIndex)
203
+ print(data)
204
+
205
+ return None
206
+
207
+
208
 
209
  # ###### Main program #######################################################
210
 
 
271
  }
272
  """
273
 
274
+
275
  # ====== Create GUI =========================================================
276
  with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:
277
+
278
+ # ====== Unique session ID for this instance =============================
279
+ sessionID = gradio.State(0)
280
+ gui.load(generateSessionID, outputs = [ sessionID ])
281
+
282
+ # ====== Header ==========================================================
283
  big_block = gradio.HTML("""
284
  <div class="header">
285
  <div class="logo-left">
 
297
  # sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
298
  dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
299
  dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
300
+ with gradio.Column():
301
+ buttonGenerate = gradio.Button("Generate ECGs!")
302
+ buttonAnalyze = gradio.Button("Analyze this ECG!")
303
+ with gradio.Row():
304
+ buttonCSV = gradio.Button("Download CSV")
305
+ buttonPDF = gradio.Button("Download PDF")
306
  gradio.Markdown('## Output')
307
  with gradio.Row():
308
  outputGallery = gradio.Gallery(label = 'output', columns = [ 1 ], height = 'auto',
309
  show_label = True,
310
  preview = True)
311
+ outputGallery.select(select)
312
+ gradio.Markdown('## Analysis')
313
 
314
  # ====== Add click event handling for "Generate" button ==================
315
  buttonGenerate.click(predict,
 
320
  outputs = [ outputGallery ]
321
  )
322
 
323
+ # ====== Add click event handling for "Analyze" button ===================
324
+ buttonAnalyze.click(analyze)
325
+
326
+ # ====== Add click event handling for download buttons ===================
327
+ buttonCSV.click(downloadCSV, inputs = [ sessionID ])
328
+ buttonPDF.click(downloadPDF, inputs = [ sessionID ])
329
+
330
  # ====== Run on startup ==================================================
331
  gui.load(predict,
332
  inputs = [ sliderNumberOfECGs,
 
338
 
339
  # ====== Run the GUI ========================================================
340
  if __name__ == "__main__":
341
+ TempDirectory = tempfile.TemporaryDirectory('DeepFakeECGPlus')
342
+ gui.launch(allowed_paths = [ TempDirectory ])
343
+ TempDirectory.cleanup()