dreibh commited on
Commit
9370c28
·
verified ·
1 Parent(s): c8548d2

Proper session handling.

Browse files
Files changed (1) hide show
  1. app.py +32 -66
app.py CHANGED
@@ -49,10 +49,8 @@ import typing
49
  import PIL
50
 
51
 
52
- TempDirectory = None
53
- Sessions = {}
54
- LastResults = None
55
- SelectedECGIndex = 0
56
 
57
 
58
  # ###### Print log message ##################################################
@@ -61,31 +59,15 @@ def log(logstring):
61
  ': ' + logstring + '\x1b[0m'));
62
 
63
 
64
- # ###### Make a unique session ID ###########################################
65
- SessionCounterLock = threading.Lock()
66
- SessionCounter = 0
67
- def generateSessionID():
68
- global SessionCounterLock
69
- global SessionCounter
70
-
71
- SessionCounterLock.acquire()
72
- SessionCounter = SessionCounter + 1
73
- sessionID = SessionCounter
74
- SessionCounterLock.release()
75
- print(f'SessionID={sessionID}')
76
-
77
- return sessionID
78
-
79
-
80
-
81
  # ###### DeepFakeECG Plus Session (session with web browser) ################
82
  class Session:
83
 
84
  # ###### Constructor #####################################################
85
  def __init__(self):
86
- self.Lock = threading.Lock()
87
- self.Counter = 0
88
- self.Results = None
 
89
 
90
  # ###### Increment counter ###############################################
91
  def increment(self):
@@ -115,27 +97,19 @@ def incrementCounter(request: gradio.Request):
115
  log(f'ERROR: Session "{request.session_hash}" is not initialized!')
116
 
117
 
118
- # ###### Get last results ###################################################
119
- def getLastResults() -> list:
120
- return LastResults
121
-
122
-
123
- # ###### Get last result ####################################################
124
- def getLastResult(index: int) -> torch.Tensor:
125
- if LastResults != None:
126
- return LastResults[index]
127
- return None
128
-
129
 
130
  # ###### Generate ECGs ######################################################
131
- def predict(numberOfECGs = 1,
132
- # ecgLengthInSeconds = 10,
133
- ecgTypeString = 'ECG-12',
134
- generatorModel = 'Default',
135
- ) -> list:
136
 
137
  ecgLengthInSeconds = 10
138
 
 
 
 
139
  # ====== Set ECG type ====================================================
140
  ecgType = deepfakeecg.DATA_ECG12
141
  if ecgTypeString == 'ECG-8':
@@ -151,19 +125,19 @@ def predict(numberOfECGs = 1,
151
  # print(matplotlib.ticker.Locator.MAXTICKS)
152
 
153
  # ====== Generate the ECGs ===============================================
154
- global LastResults
155
- LastResults = deepfakeecg.generateDeepfakeECGs(numberOfECGs,
156
- ecgType = ecgType,
157
- ecgLengthInSeconds = ecgLengthInSeconds,
158
- ecgScaleFactor = 6,
159
- outputFormat = deepfakeecg.OUTPUT_TENSOR,
160
- showProgress = False,
161
- runOnDevice = runOnDevice)
162
 
163
  # ====== Create a list of image/label tuples for gradio.Gallery ==========
164
  plotList = []
165
  ecgNumber = 1
166
- for result in LastResults:
167
 
168
  # ====== Plot ECG =====================================================
169
  result = result.t().detach().cpu().numpy()
@@ -199,15 +173,13 @@ def predict(numberOfECGs = 1,
199
 
200
 
201
  # ###### Select ECG in the gallery ##########################################
202
- def select(event: gradio.SelectData):
 
203
  # Get selection index from Gallery select() event:
204
  # https://github.com/gradio-app/gradio/issues/1976#issuecomment-1726018500
205
 
206
- global SelectedECGIndex
207
- SelectedECGIndex = event.index
208
- print(f'Selected #{SelectedECGIndex}!')
209
-
210
- # return event.value
211
 
212
 
213
  # ###### Produce CSV file from Tensor #######################################
@@ -230,25 +202,23 @@ def dataToCSV(data, outputFileName, ecgType = deepfakeecg.DATA_ECG12) -> sys.pat
230
 
231
 
232
  # ###### Download CSV #######################################################
233
- def downloadCSV(request: gradio.Request) -> None:
234
  log(f'Session "{request.session_hash}": Download CSV file')
235
 
236
 
237
  # ###### Download PDF #######################################################
238
- def downloadPDF(request: gradio.Request) -> None:
239
  log(f'Session "{request.session_hash}": Download PDF file')
240
 
241
 
242
  # ###### Analyze the selected ECG ###########################################
243
- def analyze(request: gradio.Request) -> None:
244
 
245
- log(f'Session "{request.session_hash}": Analyze #{SelectedECGIndex}!')
246
 
247
- data = getLastResult(SelectedECGIndex)
248
  print(data)
249
 
250
- return None
251
-
252
 
253
 
254
  # ###### Main program #######################################################
@@ -326,10 +296,6 @@ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.t
326
  # Session clean-up, to be called when page is closed/refreshed
327
  gui.unload(cleanUpSession)
328
 
329
- # ====== Unique session ID for this instance =============================
330
- sessionID = gradio.State(0)
331
- gui.load(generateSessionID, outputs = [ sessionID ])
332
-
333
  # ====== Header ==========================================================
334
  big_block = gradio.HTML("""
335
  <div class="header">
 
49
  import PIL
50
 
51
 
52
+ TempDirectory = None
53
+ Sessions = {}
 
 
54
 
55
 
56
  # ###### Print log message ##################################################
 
59
  ': ' + logstring + '\x1b[0m'));
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # ###### DeepFakeECG Plus Session (session with web browser) ################
63
  class Session:
64
 
65
  # ###### Constructor #####################################################
66
  def __init__(self):
67
+ self.Lock = threading.Lock()
68
+ self.Counter = 0
69
+ self.Results = None
70
+ self.Selected = 0
71
 
72
  # ###### Increment counter ###############################################
73
  def increment(self):
 
97
  log(f'ERROR: Session "{request.session_hash}" is not initialized!')
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # ###### Generate ECGs ######################################################
102
+ def predict(numberOfECGs: int = 1,
103
+ # ecgLengthInSeconds: int = 10,
104
+ ecgTypeString: str = 'ECG-12',
105
+ generatorModel: str = 'Default',
106
+ request: gradio.Request = None) -> list:
107
 
108
  ecgLengthInSeconds = 10
109
 
110
+ log(f'Session "{request.session_hash}": Generate EGCs!')
111
+
112
+
113
  # ====== Set ECG type ====================================================
114
  ecgType = deepfakeecg.DATA_ECG12
115
  if ecgTypeString == 'ECG-8':
 
125
  # print(matplotlib.ticker.Locator.MAXTICKS)
126
 
127
  # ====== Generate the ECGs ===============================================
128
+ Sessions[request.session_hash].Results = \
129
+ deepfakeecg.generateDeepfakeECGs(numberOfECGs,
130
+ ecgType = ecgType,
131
+ ecgLengthInSeconds = ecgLengthInSeconds,
132
+ ecgScaleFactor = 6,
133
+ outputFormat = deepfakeecg.OUTPUT_TENSOR,
134
+ showProgress = False,
135
+ runOnDevice = runOnDevice)
136
 
137
  # ====== Create a list of image/label tuples for gradio.Gallery ==========
138
  plotList = []
139
  ecgNumber = 1
140
+ for result in Sessions[request.session_hash].Results:
141
 
142
  # ====== Plot ECG =====================================================
143
  result = result.t().detach().cpu().numpy()
 
173
 
174
 
175
  # ###### Select ECG in the gallery ##########################################
176
+ def select(event: gradio.SelectData,
177
+ request: gradio.Request):
178
  # Get selection index from Gallery select() event:
179
  # https://github.com/gradio-app/gradio/issues/1976#issuecomment-1726018500
180
 
181
+ Sessions[request.session_hash].Selected = event.index
182
+ log(f'Session "{request.session_hash}": Selected ECG #{Sessions[request.session_hash].Selected + 1}')
 
 
 
183
 
184
 
185
  # ###### Produce CSV file from Tensor #######################################
 
202
 
203
 
204
  # ###### Download CSV #######################################################
205
+ def downloadCSV(request: gradio.Request):
206
  log(f'Session "{request.session_hash}": Download CSV file')
207
 
208
 
209
  # ###### Download PDF #######################################################
210
+ def downloadPDF(request: gradio.Request):
211
  log(f'Session "{request.session_hash}": Download PDF file')
212
 
213
 
214
  # ###### Analyze the selected ECG ###########################################
215
+ def analyze(request: gradio.Request):
216
 
217
+ log(f'Session "{request.session_hash}": Analyze ECG #{Sessions[request.session_hash].Selected + 1}!')
218
 
219
+ data = Sessions[request.session_hash].Results[Sessions[request.session_hash].Selected]
220
  print(data)
221
 
 
 
222
 
223
 
224
  # ###### Main program #######################################################
 
296
  # Session clean-up, to be called when page is closed/refreshed
297
  gui.unload(cleanUpSession)
298
 
 
 
 
 
299
  # ====== Header ==========================================================
300
  big_block = gradio.HTML("""
301
  <div class="header">