dreibh commited on
Commit
a72a986
·
verified ·
1 Parent(s): 3cd4799

Improved GUI.

Browse files
Files changed (3) hide show
  1. README.md +19 -1
  2. app.py +214 -33
  3. requirements.txt +3 -3
README.md CHANGED
@@ -11,6 +11,24 @@ license: cc-by-4.0
11
  short_description: Fake ECG Generator
12
  ---
13
 
 
 
14
  Allows to generate ECGs. Based on the following paper:
15
 
16
- https://www.nature.com/articles/s41598-021-01295-2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  short_description: Fake ECG Generator
12
  ---
13
 
14
+ # Deepfake ECG Generator GUI
15
+
16
  Allows to generate ECGs. Based on the following paper:
17
 
18
+ https://www.nature.com/articles/s41598-021-01295-2
19
+
20
+ # Run locally
21
+
22
+ ## Prepare venv and install dependencies
23
+ ```bash
24
+ mkdir -p ~/python-environments/deepfake-ecg
25
+ python3 -m venv ~/python-environments/deepfake-ecg
26
+ . ~/python-environments/deepfake-ecg/bin/activate
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ ## Run the application
31
+ ```bash
32
+ ./app.py
33
+ ```
34
+ Then, connect a web browser to [http://127.0.0.1:7860/](http://127.0.0.1:7860/) to use the application
app.py CHANGED
@@ -1,36 +1,217 @@
1
- import io
2
- import gradio as gr
3
- #from transformers import pipeline
4
- from transformers import AutoModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import ecg_plot
 
 
6
  import matplotlib.pyplot as plt
7
- from PIL import Image
 
8
  import torch
9
- #pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
10
- model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)
11
-
12
- def predict():
13
- prediction = (model(1)[0].t()/1000) # to micro volte
14
-
15
-
16
- lead_III = (prediction[1] - prediction[0]).unsqueeze(dim=0)
17
- lead_aVR = ((prediction[0] + prediction[1])*(-0.5)).unsqueeze(dim=0)
18
- lead_aVL = (prediction[0] - prediction[1]* 0.5).unsqueeze(dim=0)
19
- lead_aVF = (prediction[1] - prediction[0]* 0.5).unsqueeze(dim=0)
20
- all = torch.cat((prediction, lead_III, lead_aVR, lead_aVL, lead_aVF), dim=0)
21
- all_corrected = all[torch.tensor([0,1,8, 9, 10, 11, 2,3,4,5,6,7])]
22
-
23
- ecg_plot.plot(all_corrected, sample_rate = 500, title = 'ECG 12')
24
-
25
- #ecg_plot.show()
26
- buf = io.BytesIO()
27
- plt.savefig(buf, format="png")
28
- img = Image.open(buf)
29
- return img
30
-
31
- gr.Interface(
32
- predict,
33
- inputs=None,
34
- outputs="image",
35
- title="Generating Fake ECGs",
36
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # ==========================================================================
4
+ # ____ __ _ _____ ____ ____
5
+ # | _ \ ___ ___ _ __ / _| __ _| | _____ | ____/ ___/ ___|
6
+ # | | | |/ _ \/ _ \ '_ \| |_ / _` | |/ / _ \ | _|| | | | _
7
+ # | |_| | __/ __/ |_) | _| (_| | < __/ | |__| |__| |_| |
8
+ # |____/ \___|\___| .__/|_| \__,_|_|\_\___| |_____\____\____|
9
+ # |_|
10
+ #
11
+ # --- Deepfake ECG Generator ---
12
+ # https://github.com/vlbthambawita/deepfake-ecg
13
+ # ==========================================================================
14
+ #
15
+ # DeepfakeECG GUI Application
16
+ # Copyright (C) 2023-2025 by Vajira Thambawita
17
+ # Copyright (C) 2025 by Thomas Dreibholz
18
+ #
19
+ # This program is free software: you can redistribute it and/or modify
20
+ # it under the terms of the GNU General Public License as published by
21
+ # the Free Software Foundation, either version 3 of the License, or
22
+ # (at your option) any later version.
23
+
24
+ # This program is distributed in the hope that it will be useful,
25
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
26
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
+ # GNU General Public License for more details.
28
+ #
29
+ # You should have received a copy of the GNU General Public License
30
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
31
+ #
32
+ # Contact:
33
+ # * Vajira Thambawita <[email protected]>
34
+ # * Thomas Dreibholz <[email protected]>
35
+
36
+ import deepfakeecg
37
  import ecg_plot
38
+ import io
39
+ import gradio
40
  import matplotlib.pyplot as plt
41
+ import matplotlib.ticker
42
+ import transformers
43
  import torch
44
+ import typing
45
+ import PIL
46
+
47
+
48
+ # ###### Generate ECGs ######################################################
49
+ def predict(numberOfECGs = 1, ecgLengthInSeconds = 10, ecgTypeString = 'ECG-12') -> list:
50
+
51
+ # ====== Set ECG type ====================================================
52
+ ecgType = deepfakeecg.DATA_ECG12
53
+ if ecgTypeString == 'ECG-8':
54
+ ecgType = deepfakeecg.DATA_ECG8
55
+ elif ecgTypeString == 'ECG-12':
56
+ ecgType = deepfakeecg.DATA_ECG12
57
+ else:
58
+ sys.stderr.write(f'WARNING: Invalid ecgTypeString {ecgTypeString}, using ECG-12!\n')
59
+
60
+ # ====== Raise Locator.MAXTICKS, if necessary ============================
61
+ matplotlib.ticker.Locator.MAXTICKS = \
62
+ max(1000, ecgLengthInSeconds * deepfakeecg.ECG_SAMPLING_RATE)
63
+ # print(matplotlib.ticker.Locator.MAXTICKS)
64
+
65
+ # ====== Generate the ECGs ===============================================
66
+ results = deepfakeecg.generateDeepfakeECGs(numberOfECGs,
67
+ ecgType = ecgType,
68
+ ecgLengthInSeconds = ecgLengthInSeconds,
69
+ ecgScaleFactor = 6,
70
+ outputFormat = deepfakeecg.OUTPUT_TENSOR,
71
+ showProgress = False,
72
+ runOnDevice = runOnDevice)
73
+
74
+ # ====== Create a list of image/label tuples for gradio.Gallery ==========
75
+ plotList = []
76
+ ecgNumber = 1
77
+ for result in results:
78
+
79
+ # ====== Plot ECG =====================================================
80
+ result = result.t().detach().cpu().numpy()
81
+ # print(result)
82
+
83
+ # ------ ECG-12 -------------------------------------------------------
84
+ if ecgType == deepfakeecg.DATA_ECG12:
85
+ ecg_plot.plot(result,
86
+ title = 'ECG-12',
87
+ sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
88
+ lead_index = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'III', 'aVR', 'aVL', 'aVF' ],
89
+ lead_order = [0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7],
90
+ show_grid = True)
91
+ # ------ ECG-8 --------------------------------------------------------
92
+ else:
93
+ ecg_plot.plot(result,
94
+ title = 'ECG-8',
95
+ sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
96
+ lead_index = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ],
97
+ lead_order = [0, 1, 2, 3, 4, 5, 6, 7],
98
+ show_grid = True)
99
+
100
+ # ====== Generate WebP output =========================================
101
+ imageBuffer = io.BytesIO()
102
+ plt.savefig(imageBuffer, format = 'webp')
103
+ plt.close()
104
+ image = PIL.Image.open(imageBuffer)
105
+ plotList.append( (image, f'ECG Number {ecgNumber}') )
106
+
107
+ ecgNumber = ecgNumber + 1
108
+
109
+ return plotList
110
+
111
+
112
+
113
+ # ###### Main program #######################################################
114
+
115
+ # ====== Initialise =========================================================
116
+ runOnDevice: typing.Literal['cpu', 'cuda'] = 'cuda' if torch.cuda.is_available() else 'cpu'
117
+ css = r"""
118
+ div {
119
+ background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-essen.png");
120
+ }
121
+
122
+ /* ###### General Settings ############################################## */
123
+ html, body {
124
+ height: 100%;
125
+ padding: 0;
126
+ margin: 0;
127
+ font-family: sans-serif;
128
+ font-size: small;
129
+ background-color: #E3E3E3; /* Simula background colour: #E3E3E3 */
130
+ }
131
+
132
+
133
+ /* ###### Header ######################################################## */
134
+ div.header {
135
+ background-image: none;
136
+ background-color: #F15D22; /* Simula header colour: #F15D22 */
137
+ height: 7.5%;
138
+ display: flex;
139
+ justify-content: space-between;
140
+ }
141
+
142
+ div.logo-left {
143
+ width: 12.5%;
144
+ float: left;
145
+ display: flex;
146
+ padding: 0% 1%;
147
+ align-items: center;
148
+ background: white;
149
+ }
150
+
151
+ div.logo-right {
152
+ width: 12.5%;
153
+ float: right;
154
+ display: flex;
155
+ padding: 0% 1%;
156
+ align-items: center;
157
+ background: white;
158
+ }
159
+
160
+ div.title {
161
+ display: flex;
162
+ align-items: center;
163
+ padding: 0% 1%;
164
+ background-image: none;
165
+ background-color: #F15D22; /* Simula header colour: #F15D22 */
166
+
167
+ font-family: "Ubuntu", sans-serif;
168
+ font-size: 4vh;
169
+ font-weight: bold;
170
+ }r
171
+
172
+ img.logo-image {
173
+ max-width: 100%;
174
+ max-height: 100%;
175
+ }
176
+ """
177
+
178
+ # ====== Create GUI =========================================================
179
+ with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:
180
+ big_block = gradio.HTML("""
181
+ <div class="header">
182
+ <div class="logo-left">
183
+ <img class="logo-image" src="" alt="SimulaMet" height="32" />
184
+ </div>
185
+ <div class="title" id="title"><a href="https://ihi-search.eu/">SEARCH</a>&nbsp;Fake ECG Generator</div>
186
+ <div class="logo-right">
187
+ <img class="logo-image" src="" alt="NorNet" height="64" />
188
+ </div>
189
+ </div>
190
+ """)
191
+ gradio.Markdown('## Settings')
192
+ with gradio.Row():
193
+ sliderNumberOfECGs = gradio.Slider(1, 100, label="Number of ECGs", step = 1, value = 4, interactive = True)
194
+ sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
195
+ dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
196
+ buttonGenerate = gradio.Button("Generate")
197
+ gradio.Markdown('## Output')
198
+ with gradio.Row():
199
+ outputGallery = gradio.Gallery(label = 'output', columns = [ 1 ], height = 'auto',
200
+ show_label = True,
201
+ preview = True)
202
+
203
+ # ====== Add click event handling for "Generate" button ==================
204
+ buttonGenerate.click(predict,
205
+ inputs = [ sliderNumberOfECGs, sliderLengthInSeconds, dropdownType ],
206
+ outputs = [ outputGallery ]
207
+ )
208
+
209
+ # ====== Run on startup ==================================================
210
+ gui.load(predict,
211
+ inputs = [ sliderNumberOfECGs, sliderLengthInSeconds, dropdownType ],
212
+ outputs = [ outputGallery ]
213
+ )
214
+
215
+ # ====== Run the GUI ========================================================
216
+ if __name__ == "__main__":
217
+ gui.launch()
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- transformers
2
- torch
3
  ecg-plot
 
4
  matplotlib
5
  Pillow
6
- pydantic==2.10.6
 
 
 
 
1
  ecg-plot
2
+ gradio
3
  matplotlib
4
  Pillow
5
+ pydantic==2.10.6
6
+ torch