|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import datetime |
|
import deepfakeecg |
|
import ecg_plot |
|
import gradio |
|
import io |
|
import matplotlib.pyplot as plt |
|
import matplotlib.ticker |
|
import numpy |
|
import pathlib |
|
import random |
|
import sys |
|
import tempfile |
|
import threading |
|
import torch |
|
import typing |
|
import PIL |
|
|
|
|
|
# Root temporary directory of the application. It stays None until the
# "__main__" block creates it; Session.__init__ reads its ".name" attribute.
TempDirectory: typing.Optional[tempfile.TemporaryDirectory] = None

# Per-session state (Session objects), keyed by request.session_hash.
Sessions: dict = dict()
|
|
|
|
|
|
|
def log(logstring):
    """Print a timestamped log line to standard output.

    The line consists of the current local time (``YYYY-MM-DDTHH:MM:SS``),
    a colon and the given message, wrapped in the ANSI escape sequences
    ``ESC[34m`` ... ``ESC[0m``.

    Args:
        logstring: The message text to print.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    message = '\x1b[34m' + timestamp + ': ' + logstring + '\x1b[0m'
    print(message)
|
|
|
|
|
|
|
class Session:
    """State kept for one client session of the web application.

    Attributes:
        Lock: threading.Lock guarding updates of ``Counter``.
        Counter: Integer counter, advanced by ``increment()``.
        Selected: Index of the currently selected ECG (initially 0).
        Results: Generated ECG data of the session (initially None).
        Type: ECG type of ``Results`` (initially None).
        TempDirectory: tempfile.TemporaryDirectory of this session, created
            below the directory of the module-level ``TempDirectory``.
    """

    def __init__(self):
        """Initialise the session state and create its temporary directory."""
        self.Lock = threading.Lock()
        self.Counter = 0
        self.Selected = 0
        self.Results = None
        self.Type = None
        # The module-level TempDirectory must already have been created;
        # the session's own directory is placed inside it.
        self.TempDirectory = tempfile.TemporaryDirectory(dir = TempDirectory.name)
        log(f'Prepared temporary directory {self.TempDirectory.name}')

    def __del__(self):
        """Remove the session's temporary directory."""
        # __del__ also runs for an object whose __init__ failed before the
        # TempDirectory attribute was assigned; there is nothing to clean
        # up in that case.
        tempDirectory = getattr(self, 'TempDirectory', None)
        if tempDirectory is None:
            return
        log(f'Cleaning up temporary directory {tempDirectory.name}')
        tempDirectory.cleanup()

    def increment(self):
        """Advance ``Counter`` by one while holding ``Lock``.

        Returns:
            The new counter value.
        """
        # The attributes are named "Lock" and "Counter" (see __init__).
        with self.Lock:
            self.Counter += 1
            return self.Counter
|
|
|
|
|
|
|
def initializeSession(request: gradio.Request):
    """Create a fresh Session and register it under the request's session hash.

    Args:
        request: The Gradio request; only its ``session_hash`` is used.
    """
    newSession = Session()
    Sessions[request.session_hash] = newSession
    log(f'Session "{request.session_hash}" initialized')
|
|
|
|
|
|
|
def cleanUpSession(request: gradio.Request):
    """Remove the state of the request's session, if there is any.

    Deleting the entry drops the reference held by ``Sessions``; the
    Session object's ``__del__`` then removes its temporary directory once
    the object is destroyed.

    Args:
        request: The Gradio request; only its ``session_hash`` is used.
    """
    if request.session_hash in Sessions:
        # The registry is the module-level "Sessions" dictionary.
        del Sessions[request.session_hash]
        log(f'Session "{request.session_hash}" cleaned up')
|
|
|
|
|
|
|
def incrementCounter(request: gradio.Request):
    """Increment the counter of the request's session.

    Args:
        request: The Gradio request; only its ``session_hash`` is used.

    Returns:
        The result of the session's ``increment()``, or None (after logging
        an error) when no session is registered for the request.
    """
    if request.session_hash not in Sessions:
        log(f'ERROR: Session "{request.session_hash}" is not initialized!')
        return None
    return Sessions[request.session_hash].increment()
|
|
|
|
|
|
|
|
|
def predict(numberOfECGs: int = 1,
            ecgTypeString: str = 'ECG-12',
            generatorModel: str = 'Default',
            request: gradio.Request = None) -> list:
    """Generate ECGs, store them in the session and render them as images.

    The generated data and its ECG type are stored in the ``Results`` and
    ``Type`` attributes of the request's Session, where the download and
    analysis handlers read them.

    Args:
        numberOfECGs: Number of ECGs to generate.
        ecgTypeString: 'ECG-12' or 'ECG-8'. Any other value triggers a
            warning on stderr and falls back to ECG-12.
        generatorModel: Currently not used by this function.
        request: The Gradio request; its ``session_hash`` selects the
            Session. It must not be None when the function runs.

    Returns:
        A list of ``(image, caption)`` tuples, one per generated ECG, where
        ``image`` is a PIL image of the plot and ``caption`` is
        'ECG Number <n>' (counting from 1).
    """
    # "import PIL" alone does not guarantee that the PIL.Image submodule
    # is loaded, so import it explicitly before using PIL.Image.open().
    import PIL.Image

    ecgLengthInSeconds = 10

    log(f'Session "{request.session_hash}": Generate ECGs!')

    # ====== Map the type string to the deepfakeecg constant ===============
    if ecgTypeString == 'ECG-8':
        ecgType = deepfakeecg.DATA_ECG8
    else:
        if ecgTypeString != 'ECG-12':
            sys.stderr.write(f'WARNING: Invalid ecgTypeString {ecgTypeString}, using ECG-12!\n')
        ecgType = deepfakeecg.DATA_ECG12

    # Raise matplotlib's tick limit to at least one tick per sample.
    matplotlib.ticker.Locator.MAXTICKS = \
        max(1000, ecgLengthInSeconds * deepfakeecg.ECG_SAMPLING_RATE)

    # ====== Generate the ECGs =============================================
    results = deepfakeecg.generateDeepfakeECGs(numberOfECGs,
                                               ecgType = ecgType,
                                               ecgLengthInSeconds = ecgLengthInSeconds,
                                               ecgScaleFactor = deepfakeecg.ECG_DEFAULT_SCALE_FACTOR,
                                               outputFormat = deepfakeecg.OUTPUT_TENSOR,
                                               showProgress = False,
                                               runOnDevice = runOnDevice)

    # The session is looked up only after the generation has finished, as
    # before, and then updated with the results and their type.
    session = Sessions[request.session_hash]
    session.Results = results
    session.Type = ecgType

    # ====== Plot settings, depending on the ECG type ======================
    info = '25 mm/sec, 1 mV/10 mm'
    if ecgType == deepfakeecg.DATA_ECG12:
        title = 'ECG-12 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'III', 'aVR', 'aVL', 'aVF' ]
        leadOrder = [0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7]
    else:
        title = 'ECG-8 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ]
        leadOrder = [0, 1, 2, 3, 4, 5, 6, 7]

    # ====== Render each ECG into an in-memory WebP image ==================
    plotList = []
    for ecgNumber, result in enumerate(results, start = 1):
        # Transpose, drop the first row and divide by 1000.
        # NOTE(review): presumably the first column of the tensor is a
        # timestamp/sample index and the values are in µV (-> mV); confirm
        # against the deepfakeecg documentation.
        signals = result.t().detach().cpu().numpy()[1:] / 1000

        ecg_plot.plot(signals,
                      title = title,
                      sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
                      lead_index = leadIndex,
                      lead_order = leadOrder,
                      show_grid = True)

        imageBuffer = io.BytesIO()
        plt.savefig(imageBuffer, format = 'webp')
        plt.close()
        image = PIL.Image.open(imageBuffer)
        plotList.append( (image, f'ECG Number {ecgNumber}') )

    return plotList
|
|
|
|
|
|
|
def select(event: gradio.SelectData,
           request: gradio.Request):
    """Remember which ECG has been selected in the request's session.

    Args:
        event: The selection event; its ``index`` is stored as ``Selected``.
        request: The Gradio request; only its ``session_hash`` is used.
    """
    selectedIndex = event.index
    session = Sessions[request.session_hash]
    session.Selected = selectedIndex
    log(f'Session "{request.session_hash}": Selected ECG #{Sessions[request.session_hash].Selected + 1}')
|
|
|
|
|
|
|
def downloadCSV(request: gradio.Request) -> pathlib.Path:
    """Write the selected ECG of the session to a CSV file.

    The file is named 'ECG-<n>.csv' (with n = selected index + 1) and is
    created in the session's temporary directory via
    ``deepfakeecg.dataToCSV``.

    Args:
        request: The Gradio request; only its ``session_hash`` is used.

    Returns:
        The path of the written CSV file.
    """
    session = Sessions[request.session_hash]
    ecgResult = session.Results[session.Selected]
    ecgType = session.Type

    baseName = 'ECG-' + str(session.Selected + 1) + '.csv'
    fileName = pathlib.Path(session.TempDirectory.name) / baseName
    deepfakeecg.dataToCSV(ecgResult, ecgType, fileName)

    log(f'Session "{request.session_hash}": Download CSV file {fileName}')
    return fileName
|
|
|
|
|
|
|
|
|
def downloadPDF(request: gradio.Request):
    """Write the selected ECG of the session to a PDF file.

    The file is named 'ECG-<n>.pdf' (with n = selected index + 1) and is
    created in the session's temporary directory via
    ``deepfakeecg.dataToPDF``. The list of output leads depends on the ECG
    type stored in the session.

    Args:
        request: The Gradio request; only its ``session_hash`` is used.

    Returns:
        The path of the written PDF file.
    """
    session = Sessions[request.session_hash]
    ecgResult = session.Results[session.Selected]
    ecgType = session.Type

    baseName = 'ECG-' + str(session.Selected + 1) + '.pdf'
    fileName = pathlib.Path(session.TempDirectory.name) / baseName

    outputLeads = (
        [ 'I', 'II', 'III', 'aVL', 'aVR', 'aVF', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
        if ecgType == deepfakeecg.DATA_ECG12 else
        [ 'I', 'II', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
    )

    deepfakeecg.dataToPDF(ecgResult, ecgType, outputLeads, fileName,
                          session.Selected + 1)

    log(f'Session "{request.session_hash}": Download PDF file {fileName}')
    return fileName
|
|
|
|
|
|
|
def analyze(request: gradio.Request):
    """Log the analysis request and print the selected ECG's data.

    At present the function only prints the data of the selected ECG to
    standard output; it returns nothing.

    Args:
        request: The Gradio request; only its ``session_hash`` is used.
    """
    log(f'Session "{request.session_hash}": Analyze ECG #{Sessions[request.session_hash].Selected + 1}!')

    session = Sessions[request.session_hash]
    data = session.Results[session.Selected]
    print(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Device passed to deepfakeecg.generateDeepfakeECGs() by predict():
# 'cuda' when torch reports CUDA as available, otherwise 'cpu'.
runOnDevice: typing.Literal['cpu', 'cuda'] = 'cpu'
if torch.cuda.is_available():
    runOnDevice = 'cuda'
|
# Custom style sheet handed to gradio.Blocks(css = ...). It styles the page
# background and the header (logos and title) of the HTML block.
css = r"""
div {
   background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-essen.png");
}

/* ###### General Settings ############################################## */
html, body {
   height: 100%;
   padding: 0;
   margin: 0;
   font-family: sans-serif;
   font-size: small;
   background-color: #E3E3E3;   /* Simula background colour: #E3E3E3 */
}


/* ###### Header ######################################################## */
div.header {
   background-image: none;
   background-color: #F15D22;   /* Simula header colour: #F15D22 */
   height: 7.5%;
   display: flex;
   justify-content: space-between;
}

div.logo-left {
   width: 12.5%;
   float: left;
   display: flex;
   padding: 0% 1%;
   align-items: center;
   background: white;
}

div.logo-right {
   width: 12.5%;
   float: right;
   display: flex;
   padding: 0% 1%;
   align-items: center;
   background: white;
}

div.title {
   display: flex;
   align-items: center;
   padding: 0% 1%;
   background-image: none;
   background-color: #F15D22;   /* Simula header colour: #F15D22 */

   font-family: "Ubuntu", sans-serif;
   font-size: 4vh;
   font-weight: bold;
}

img.logo-image {
   max-width: 100%;
   max-height: 100%;
}
"""
|
|
|
|
|
|
|
# ###### Gradio user interface ##############################################
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statement order; confirm it against the intended layout.
with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:

    # ====== Session handling ===============================================
    # Create the per-session state on page load ...
    gui.load(initializeSession)
    # ... and remove it again on unload.
    gui.unload(cleanUpSession)

    # ====== Header =========================================================
    big_block = gradio.HTML("""
<div class="header">
  <div class="logo-left">
    <img class="logo-image" src="" alt="SimulaMet" height="32" />
  </div>
  <div class="title" id="title"><a href="https://ihi-search.eu/">SEARCH</a> Fake ECG Generator</div>
  <div class="logo-right">
    <img class="logo-image" src="" alt="NorNet" height="64" />
  </div>
</div>
""")

    # ====== Settings =======================================================
    gradio.Markdown('## Settings')
    with gradio.Row():
        sliderNumberOfECGs = gradio.Slider(1, 100, label="Number of ECGs", step = 1, value = 4, interactive = True)
        dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
        dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
        with gradio.Column():
            buttonGenerate = gradio.Button("Generate ECGs!")
            buttonAnalyze = gradio.Button("Analyze this ECG!")
            with gradio.Row():
                # Each visible download button is paired with a hidden one:
                # the visible button's handler produces the file for the
                # hidden button, which is then clicked via JavaScript.
                buttonCSV = gradio.DownloadButton("Download CSV")
                buttonCSV_hidden = gradio.DownloadButton(visible=False, elem_id="download_csv_hidden")
                buttonPDF = gradio.DownloadButton("Download PDF")
                buttonPDF_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdf_hidden")

    # ====== Output =========================================================
    gradio.Markdown('## Output')
    with gradio.Row():
        outputGallery = gradio.Gallery(label = 'output',
                                       columns = [ 1 ],
                                       height = 'auto',
                                       show_label = True,
                                       preview = True)
        # Track the selected ECG in the session state.
        outputGallery.select(select)

    # ====== Analysis =======================================================
    gradio.Markdown('## Analysis')

    # ====== Event wiring ===================================================
    buttonGenerate.click(predict,
                         inputs = [ sliderNumberOfECGs,
                                    dropdownType,
                                    dropdownGeneratorModel ],
                         outputs = [ outputGallery ]
                        )

    buttonAnalyze.click(analyze)

    # One handler per download button: write the file, hand it to the
    # hidden button, then trigger the hidden button's click in the browser.
    # (A second registration of the same handler without outputs would
    # only generate the file a second time.)
    buttonCSV.click(fn = downloadCSV, inputs = None, outputs = [ buttonCSV_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_csv_hidden').click()")
    buttonPDF.click(fn = downloadPDF, inputs = None, outputs = [ buttonPDF_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_pdf_hidden').click()")

    # Generate an initial set of ECGs on page load, using the current
    # values of the settings widgets.
    gui.load(predict,
             inputs = [ sliderNumberOfECGs,
                        dropdownType,
                        dropdownGeneratorModel ],
             outputs = [ outputGallery ]
            )
|
|
|
|
|
|
|
if __name__ == "__main__":

    # ====== Prepare the root temporary directory ===========================
    # Session objects create their own directories inside this one.
    TempDirectory = tempfile.TemporaryDirectory(prefix = 'DeepFakeECGPlus-')
    log(f'Prepared temporary directory {TempDirectory.name}')

    # ====== Run the GUI ====================================================
    try:
        # The temporary directory is passed as an allowed path, so the
        # generated download files inside it can be served.
        gui.launch(allowed_paths = [ TempDirectory.name ])
    finally:
        # ====== Clean up, even when launch() raised ========================
        log(f'Cleaning up temporary directory {TempDirectory.name}')
        TempDirectory.cleanup()

    log('Done!')
|
|