dreibh's picture
Added ECG information (25 mm/sec, 1 mV/10 mm) to plot.
b62eb97 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ==========================================================================
# ____ __ _ _____ ____ ____
# | _ \ ___ ___ _ __ / _| __ _| | _____ | ____/ ___/ ___|
# | | | |/ _ \/ _ \ '_ \| |_ / _` | |/ / _ \ | _|| | | | _
# | |_| | __/ __/ |_) | _| (_| | < __/ | |__| |__| |_| |
# |____/ \___|\___| .__/|_| \__,_|_|\_\___| |_____\____\____|
# |_|
#
# --- Deepfake ECG Generator ---
# https://github.com/vlbthambawita/deepfake-ecg
# ==========================================================================
#
# DeepfakeECG GUI Application
# Copyright (C) 2023-2025 by Vajira Thambawita
# Copyright (C) 2025 by Thomas Dreibholz
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Contact:
# * Vajira Thambawita <[email protected]>
# * Thomas Dreibholz <[email protected]>
import datetime
import io
import pathlib
import random
import sys
import tempfile
import threading
import typing

import deepfakeecg
import ecg_plot
import gradio
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy
import torch
import PIL
import PIL.Image
# Application-wide temporary directory. It stays None until the __main__
# section at the end of this file creates it.
TempDirectory = None

# State of all connected browser sessions, keyed by gradio.Request.session_hash.
Sessions = dict()
# ###### Print log message ##################################################
def log(logstring):
    """Print a log message in blue, prefixed with the local time.

    The printed line has the form
    ``<ESC>[34mYYYY-MM-DDTHH:MM:SS: <logstring><ESC>[0m``; the trailing ANSI
    reset sequence restores the terminal colour.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    print(''.join(('\x1b[34m', timestamp, ': ', logstring, '\x1b[0m')))
# ###### DeepFakeECG Plus Session (session with web browser) ################
class Session:
    """State of one browser session of the DeepFakeECG Plus GUI.

    Attributes:
      Lock          -- threading.Lock protecting Counter.
      Counter       -- counter advanced by increment().
      Selected      -- index of the ECG selected in the gallery (set by select()).
      Results       -- ECGs generated for this session (set by predict()), or None.
      Type          -- ECG type of Results (set by predict()), or None.
      TempDirectory -- private tempfile.TemporaryDirectory for download files.
    """

    # ###### Constructor #####################################################
    def __init__(self):
        self.Lock = threading.Lock()
        self.Counter = 0
        self.Selected = 0
        self.Results = None
        self.Type = None
        # The per-session directory is created inside the application-wide
        # temporary directory, which the __main__ section passes to
        # gui.launch(allowed_paths = ...). If that directory does not exist
        # (yet), fall back to the system's default temporary location instead
        # of crashing with an AttributeError on None.
        # NOTE(review): files outside allowed_paths may not be downloadable.
        parentDirectory = TempDirectory.name if TempDirectory is not None else None
        self.TempDirectory = tempfile.TemporaryDirectory(dir = parentDirectory)
        log(f'Prepared temporary directory {self.TempDirectory.name}')

    # ###### Destructor ######################################################
    def __del__(self):
        # __init__ may have raised before the directory was created.
        tempDirectory = getattr(self, 'TempDirectory', None)
        if tempDirectory is not None:
            log(f'Cleaning up temporary directory {tempDirectory.name}')
            tempDirectory.cleanup()

    # ###### Increment counter ###############################################
    def increment(self):
        """Increment the session counter under the lock; return its new value."""
        with self.Lock:
            self.Counter += 1
            return self.Counter
# ###### Initialize a new session ###########################################
def initializeSession(request: gradio.Request):
    """Create a fresh Session object for the browser session of *request*."""
    sessionHash = request.session_hash
    Sessions[sessionHash] = Session()
    log(f'Session "{sessionHash}" initialized')
# ###### Clean up a session #################################################
def cleanUpSession(request: gradio.Request):
    """Forget the state of the browser session of *request*.

    Removing the entry from Sessions drops this module's reference to the
    Session object; once no other references remain, Session.__del__ removes
    the session's temporary directory. Unknown sessions are ignored.
    """
    if request.session_hash in Sessions:
        del Sessions[request.session_hash]
        log(f'Session "{request.session_hash}" cleaned up')
# ###### Increment counter in session #######################################
def incrementCounter(request: gradio.Request):
    """Increment the counter of the session of *request*.

    Returns the new counter value, or None (after logging an error) when the
    session has not been initialized.
    """
    sessionHash = request.session_hash
    if sessionHash not in Sessions:
        log(f'ERROR: Session "{sessionHash}" is not initialized!')
        return None
    return Sessions[sessionHash].increment()
# ###### Generate ECGs ######################################################
def predict(numberOfECGs: int = 1,
            # ecgLengthInSeconds: int = 10,
            ecgTypeString: str = 'ECG-12',
            generatorModel: str = 'Default',
            request: gradio.Request = None) -> list:
    """Generate deepfake ECGs for the calling session and plot them.

    The generated tensors are stored in the session (``Results`` and
    ``Type``), so that the selection, download and analysis handlers can
    use them afterwards.

    :param numberOfECGs:   Number of ECGs to generate.
    :param ecgTypeString:  'ECG-12' or 'ECG-8'; any other value is reported
                           on stderr and treated as 'ECG-12'.
    :param generatorModel: Name of the generator model (currently unused).
    :param request:        Gradio request; its session_hash selects the session.
    :return: List of (image, caption) tuples for gradio.Gallery.
    """
    # The length is fixed for now (the corresponding GUI slider is disabled).
    ecgLengthInSeconds = 10
    log(f'Session "{request.session_hash}": Generate ECGs!')

    # ====== Set ECG type ====================================================
    ecgType = deepfakeecg.DATA_ECG12
    if ecgTypeString == 'ECG-8':
        ecgType = deepfakeecg.DATA_ECG8
    elif ecgTypeString != 'ECG-12':
        sys.stderr.write(f'WARNING: Invalid ecgTypeString {ecgTypeString}, using ECG-12!\n')

    # ====== Raise Locator.MAXTICKS, if necessary ============================
    # Presumably the ECG grid drawn by ecg_plot needs more ticks than
    # matplotlib's default limit; one tick per sample is the upper bound used.
    matplotlib.ticker.Locator.MAXTICKS = \
        max(1000, ecgLengthInSeconds * deepfakeecg.ECG_SAMPLING_RATE)

    # ====== Generate the ECGs ===============================================
    session = Sessions[request.session_hash]
    session.Results = \
        deepfakeecg.generateDeepfakeECGs(numberOfECGs,
                                         ecgType = ecgType,
                                         ecgLengthInSeconds = ecgLengthInSeconds,
                                         ecgScaleFactor = deepfakeecg.ECG_DEFAULT_SCALE_FACTOR,
                                         outputFormat = deepfakeecg.OUTPUT_TENSOR,
                                         showProgress = False,
                                         runOnDevice = runOnDevice)
    session.Type = ecgType

    # ====== Plot settings (identical for all ECGs of this run) ==============
    info = '25 mm/sec, 1 mV/10 mm'
    if ecgType == deepfakeecg.DATA_ECG12:
        plotTitle = 'ECG-12 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'III', 'aVR', 'aVL', 'aVF' ]
        leadOrder = [0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7]
    else:
        plotTitle = 'ECG-8 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ]
        leadOrder = [0, 1, 2, 3, 4, 5, 6, 7]

    # ====== Create a list of image/label tuples for gradio.Gallery ==========
    plotList = []
    for ecgNumber, result in enumerate(session.Results, start = 1):
        # ====== Plot ECG ====================================================
        # 1. Transpose and convert to NumPy
        # 2. Remove the Timestamp column (0), i.e. row 0 after transposing
        # 3. Convert from µV to mV
        samples = result.t().detach().cpu().numpy()[1:] / 1000
        ecg_plot.plot(samples,
                      title = plotTitle,
                      sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
                      lead_index = leadIndex,
                      lead_order = leadOrder,
                      show_grid = True)

        # ====== Generate WebP output ========================================
        imageBuffer = io.BytesIO()
        plt.savefig(imageBuffer, format = 'webp')
        plt.close()
        imageBuffer.seek(0)   # Read the freshly written image from the start.
        image = PIL.Image.open(imageBuffer)
        plotList.append( (image, f'ECG Number {ecgNumber}') )

    return plotList
# ###### Select ECG in the gallery ##########################################
def select(event: gradio.SelectData,
           request: gradio.Request):
    """Remember which gallery entry the user selected in this session.

    The selection index comes from the Gallery select() event, see
    https://github.com/gradio-app/gradio/issues/1976#issuecomment-1726018500
    """
    session = Sessions[request.session_hash]
    session.Selected = event.index
    log(f'Session "{request.session_hash}": Selected ECG #{session.Selected + 1}')
# ###### Download CSV #######################################################
def downloadCSV(request: gradio.Request) -> pathlib.Path:
    """Write the selected ECG of this session to a CSV file; return its path.

    The file is named ECG-<number>.csv (1-based selection number) and is
    placed in the session's temporary directory.
    """
    session = Sessions[request.session_hash]
    ecgResult = session.Results[session.Selected]
    ecgType = session.Type
    outputDirectory = pathlib.Path(session.TempDirectory.name)
    fileName = outputDirectory / f'ECG-{session.Selected + 1}.csv'

    deepfakeecg.dataToCSV(ecgResult, ecgType, fileName)
    log(f'Session "{request.session_hash}": Download CSV file {fileName}')
    return fileName
# ###### Download PDF #######################################################
def downloadPDF(request: gradio.Request):
    """Write the selected ECG of this session to a PDF file; return its path.

    ECG-12 data is written with all twelve leads, anything else with the
    eight ECG-8 leads. The file is named ECG-<number>.pdf (1-based selection
    number) and is placed in the session's temporary directory.
    """
    session = Sessions[request.session_hash]
    ecgResult = session.Results[session.Selected]
    ecgType = session.Type
    ecgNumber = session.Selected + 1
    fileName = pathlib.Path(session.TempDirectory.name) / f'ECG-{ecgNumber}.pdf'

    leads = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ]
    if ecgType == deepfakeecg.DATA_ECG12:
        leads = [ 'I', 'II', 'III', 'aVL', 'aVR', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ]

    deepfakeecg.dataToPDF(ecgResult, ecgType, leads, fileName, ecgNumber)
    log(f'Session "{request.session_hash}": Download PDF file {fileName}')
    return fileName
# ###### Analyze the selected ECG ###########################################
def analyze(request: gradio.Request):
    """Analyze the selected ECG of this session (currently: print its data)."""
    session = Sessions[request.session_hash]
    log(f'Session "{request.session_hash}": Analyze ECG #{session.Selected + 1}!')
    selectedData = session.Results[session.Selected]
    print(selectedData)
# ###### Main program #######################################################
# ====== Initialise =========================================================
# Device passed to deepfakeecg.generateDeepfakeECGs() by predict():
# use the GPU when CUDA is available, the CPU otherwise.
runOnDevice: typing.Literal['cpu', 'cuda']
if torch.cuda.is_available():
    runOnDevice = 'cuda'
else:
    runOnDevice = 'cpu'
# Custom style sheet for the Gradio user interface (passed to gradio.Blocks).
css = r"""
div {
    background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-essen.png");
}

/* ###### General Settings ############################################## */
html, body {
    height: 100%;
    padding: 0;
    margin: 0;
    font-family: sans-serif;
    font-size: small;
    background-color: #E3E3E3; /* Simula background colour: #E3E3E3 */
}

/* ###### Header ######################################################## */
div.header {
    background-image: none;
    background-color: #F15D22; /* Simula header colour: #F15D22 */
    height: 7.5%;
    display: flex;
    justify-content: space-between;
}

div.logo-left {
    width: 12.5%;
    float: left;
    display: flex;
    padding: 0% 1%;
    align-items: center;
    background: white;
}

div.logo-right {
    width: 12.5%;
    float: right;
    display: flex;
    padding: 0% 1%;
    align-items: center;
    background: white;
}

div.title {
    display: flex;
    align-items: center;
    padding: 0% 1%;
    background-image: none;
    background-color: #F15D22; /* Simula header colour: #F15D22 */
    font-family: "Ubuntu", sans-serif;
    font-size: 4vh;
    font-weight: bold;
}

img.logo-image {
    max-width: 100%;
    max-height: 100%;
}
"""
# ====== Create GUI =========================================================
with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:

    # ====== Session handling ================================================
    # Session initialization, to be called when page is loaded
    gui.load(initializeSession)
    # Session clean-up, to be called when page is closed/refreshed
    gui.unload(cleanUpSession)

    # ====== Header ==========================================================
    big_block = gradio.HTML("""
<div class="header">
<div class="logo-left">
<img class="logo-image" src="" alt="SimulaMet" height="32" />
</div>
<div class="title" id="title"><a href="https://ihi-search.eu/">SEARCH</a>&nbsp;Fake ECG Generator</div>
<div class="logo-right">
<img class="logo-image" src="" alt="NorNet" height="64" />
</div>
</div>
""")

    # ====== Settings ========================================================
    gradio.Markdown('## Settings')
    with gradio.Row():
        sliderNumberOfECGs = gradio.Slider(1, 100, label="Number of ECGs", step = 1, value = 4, interactive = True)
        # sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
        dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
        dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
        with gradio.Column():
            buttonGenerate = gradio.Button("Generate ECGs!")
            buttonAnalyze = gradio.Button("Analyze this ECG!")
            with gradio.Row():
                buttonCSV = gradio.DownloadButton("Download CSV")
                buttonCSV_hidden = gradio.DownloadButton(visible=False, elem_id="download_csv_hidden")
                buttonPDF = gradio.DownloadButton("Download PDF")
                buttonPDF_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdf_hidden")

    # ====== Output ==========================================================
    gradio.Markdown('## Output')
    with gradio.Row():
        outputGallery = gradio.Gallery(label = 'output',
                                       columns = [ 1 ],
                                       height = 'auto',
                                       show_label = True,
                                       preview = True)
    # Remember which gallery entry the user selects (see select()):
    outputGallery.select(select)

    gradio.Markdown('## Analysis')

    # ====== Add click event handling for "Generate" button ==================
    buttonGenerate.click(predict,
                         inputs = [ sliderNumberOfECGs,
                                    # sliderLengthInSeconds,
                                    dropdownType,
                                    dropdownGeneratorModel ],
                         outputs = [ outputGallery ]
                        )

    # ====== Add click event handling for "Analyze" button ===================
    buttonAnalyze.click(analyze)

    # ====== Add click event handling for download buttons ===================
    # Using hidden button and JavaScript, to generate download file on-the-fly:
    # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
    # Each visible button gets exactly ONE listener: it writes the file, hands
    # it to the hidden button, and the JavaScript step clicks that hidden
    # button to start the download. A second plain click(downloadX) listener
    # would only generate the same file again, into the same path.
    buttonCSV.click(fn = downloadCSV, inputs = None, outputs = [ buttonCSV_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_csv_hidden').click()")
    buttonPDF.click(fn = downloadPDF, inputs = None, outputs = [ buttonPDF_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_pdf_hidden').click()")

    # ====== Run on startup ==================================================
    gui.load(predict,
             inputs = [ sliderNumberOfECGs,
                        # sliderLengthInSeconds,
                        dropdownType,
                        dropdownGeneratorModel ],
             outputs = [ outputGallery ]
            )
# ====== Run the GUI ========================================================
if __name__ == "__main__":
    # ------ Prepare temporary directory -------------------------------------
    # Assigned at module level, so that Session.__init__ creates the
    # per-session directories inside it.
    TempDirectory = tempfile.TemporaryDirectory(prefix = 'DeepFakeECGPlus-')
    log(f'Prepared temporary directory {TempDirectory.name}')

    try:
        # ------ Run the GUI, with downloads from temporary directory allowed
        gui.launch(allowed_paths = [ TempDirectory.name ])
    finally:
        # ------ Clean up ----------------------------------------------------
        # Also runs when launch() raises, so the directory is never left behind.
        log(f'Cleaning up temporary directory {TempDirectory.name}')
        TempDirectory.cleanup()
    log('Done!')