#!/usr/bin/env python # -*- coding: utf-8 -*- # ========================================================================== # ____ __ _ _____ ____ ____ # | _ \ ___ ___ _ __ / _| __ _| | _____ | ____/ ___/ ___| # | | | |/ _ \/ _ \ '_ \| |_ / _` | |/ / _ \ | _|| | | | _ # | |_| | __/ __/ |_) | _| (_| | < __/ | |__| |__| |_| | # |____/ \___|\___| .__/|_| \__,_|_|\_\___| |_____\____\____| # |_| # # --- Deepfake ECG Generator --- # https://github.com/vlbthambawita/deepfake-ecg # ========================================================================== # # DeepfakeECG GUI Application # Copyright (C) 2023-2025 by Vajira Thambawita # Copyright (C) 2025 by Thomas Dreibholz # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . 
#
# Contact:
# * Vajira Thambawita
# * Thomas Dreibholz

import datetime
import deepfakeecg
import ecg_plot
import gradio
import io
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy
import pathlib
import random
import sys
import tempfile
import threading
import torch
import typing
import PIL
# "import PIL" alone does not load the Image submodule; import it explicitly
# so that PIL.Image.open() does not depend on another package having loaded it.
import PIL.Image


# Parent temporary directory for all sessions; set by the main program.
TempDirectory = None
# Maps the Gradio session hash to its Session object.
Sessions = {}


# ###### Print log message ##################################################
def log(logstring):
    """Print a log line to stdout, prefixed with the current local timestamp.

    The line is wrapped in the ANSI escape sequences for blue text.

    Args:
        logstring: The message text to print.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    print('\x1b[34m' + timestamp + ': ' + logstring + '\x1b[0m')


# ###### DeepFakeECG Plus Session (session with web browser) ################
class Session:
    """Per-browser-session state: generated ECGs, selection and temp files."""

    # ###### Constructor #####################################################
    def __init__(self):
        """Create the session state and its private temporary directory.

        The directory is created inside the global TempDirectory, which
        therefore must have been set up before a Session is constructed.
        """
        self.Lock = threading.Lock()
        self.Counter = 0
        self.Selected = 0      # Index of the currently selected ECG
        self.Results = None    # Generated ECGs (None until first generation)
        self.Type = None       # ECG type of the generated ECGs
        self.TempDirectory = tempfile.TemporaryDirectory(dir = TempDirectory.name)
        log(f'Prepared temporary directory {self.TempDirectory.name}')

    # ###### Destructor ######################################################
    def __del__(self):
        """Remove the session's temporary directory."""
        # __del__ also runs when __init__ failed before TempDirectory was
        # assigned, so the attribute may be missing.
        tempDirectory = getattr(self, 'TempDirectory', None)
        if tempDirectory is not None:
            log(f'Cleaning up temporary directory {tempDirectory.name}')
            tempDirectory.cleanup()

    # ###### Increment counter ###############################################
    def increment(self):
        """Increment the session counter under the lock.

        Returns:
            The counter value after the increment.
        """
        with self.Lock:
            self.Counter += 1
            return self.Counter


# ###### Initialize a new session ###########################################
def initializeSession(request: gradio.Request):
    """Create a fresh Session for the requesting browser session.

    An existing entry for the same session hash is replaced.
    """
    Sessions[request.session_hash] = Session()
    log(f'Session "{request.session_hash}" initialized')


# ###### Clean up a session #################################################
def cleanUpSession(request: gradio.Request):
    """Remove the Session of the requesting browser session, if there is one."""
    if request.session_hash in Sessions:
        del Sessions[request.session_hash]
        log(f'Session "{request.session_hash}" cleaned up')


# ###### Increment counter in session
# #######################################
def incrementCounter(request: gradio.Request):
    """Increment the counter of the requesting session.

    Returns:
        The new counter value, or None (after logging an error) when no
        session exists for the request.
    """
    session = Sessions.get(request.session_hash)
    if session is not None:
        return session.increment()
    log(f'ERROR: Session "{request.session_hash}" is not initialized!')
    return None


# ###### Look up the session of a request ###################################
def _getSession(request: gradio.Request):
    """Return the Session belonging to the request.

    Raises:
        gradio.Error: If there is no request or no session for it, so that
            the user gets a readable message instead of a KeyError.
    """
    session = Sessions.get(request.session_hash) if request is not None else None
    if session is None:
        raise gradio.Error('Session is not initialized. Please reload the page.')
    return session


# ###### Look up the selected ECG of a request ##############################
def _getSelectedECG(request: gradio.Request):
    """Return (session, selectedIndex, ecgResult) for the request.

    Raises:
        gradio.Error: If there is no session, no ECGs have been generated
            yet, or the stored selection index is out of range (e.g. a stale
            selection after generating fewer ECGs than before).
    """
    session = _getSession(request)
    results = session.Results
    # Read the index once, so a concurrent select() cannot change it between
    # the range check and its use.
    selected = session.Selected
    if results is None or not (0 <= selected < len(results)):
        raise gradio.Error('No ECG available. Please generate and select an ECG first.')
    return session, selected, results[selected]


# ###### Generate ECGs ######################################################
def predict(numberOfECGs: int = 1,
            # ecgLengthInSeconds: int = 10,
            ecgTypeString: str = 'ECG-12',
            generatorModel: str = 'Default',
            request: gradio.Request = None) -> list:
    """Generate ECGs, store them in the session and render them as images.

    Args:
        numberOfECGs: Number of ECGs to generate.
        ecgTypeString: 'ECG-8' or 'ECG-12'; anything else falls back to
            ECG-12 with a warning on stderr.
        generatorModel: Currently not used by this function.
        request: The Gradio request identifying the session.

    Returns:
        A list of (image, label) tuples for gradio.Gallery.

    Raises:
        gradio.Error: If the session of the request is not initialized.
    """
    ecgLengthInSeconds = 10
    session = _getSession(request)
    log(f'Session "{request.session_hash}": Generate EGCs!')

    # ====== Set ECG type ====================================================
    if ecgTypeString == 'ECG-8':
        ecgType = deepfakeecg.DATA_ECG8
    else:
        if ecgTypeString != 'ECG-12':
            sys.stderr.write(f'WARNING: Invalid ecgTypeString {ecgTypeString}, using ECG-12!\n')
        ecgType = deepfakeecg.DATA_ECG12

    # ====== Raise Locator.MAXTICKS, if necessary ============================
    matplotlib.ticker.Locator.MAXTICKS = \
        max(1000, ecgLengthInSeconds * deepfakeecg.ECG_SAMPLING_RATE)

    # ====== Generate the ECGs ===============================================
    session.Results = \
        deepfakeecg.generateDeepfakeECGs(numberOfECGs,
                                         ecgType = ecgType,
                                         ecgLengthInSeconds = ecgLengthInSeconds,
                                         ecgScaleFactor = deepfakeecg.ECG_DEFAULT_SCALE_FACTOR,
                                         outputFormat = deepfakeecg.OUTPUT_TENSOR,
                                         showProgress = False,
                                         runOnDevice = runOnDevice)
    session.Type = ecgType

    # ====== Plot settings (the same for every ECG of this run) ==============
    info = '25 mm/sec, 1 mV/10 mm'
    if ecgType == deepfakeecg.DATA_ECG12:
        title = 'ECG-12 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'III', 'aVR', 'aVL', 'aVF' ]
        leadOrder = [ 0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7 ]
    else:
        title = 'ECG-8 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ]
        leadOrder = [ 0, 1, 2, 3, 4, 5, 6, 7 ]

    # ====== Create a list of image/label tuples for gradio.Gallery ==========
    plotList = []
    for ecgNumber, result in enumerate(session.Results, start = 1):
        # ====== Plot ECG =====================================================
        # 1. Convert to NumPy
        # 2. Remove the Timestamp column (0)
        # 3. Convert from µV to mV
        result = result.t().detach().cpu().numpy()[1:] / 1000
        ecg_plot.plot(result,
                      title = title,
                      sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
                      lead_index = leadIndex,
                      lead_order = leadOrder,
                      show_grid = True)

        # ====== Generate WebP output =========================================
        imageBuffer = io.BytesIO()
        plt.savefig(imageBuffer, format = 'webp')
        plt.close()
        imageBuffer.seek(0)   # Read the image back from the start of the buffer
        image = PIL.Image.open(imageBuffer)
        plotList.append( (image, f'ECG Number {ecgNumber}') )

    return plotList


# ###### Select ECG in the gallery ##########################################
def select(event: gradio.SelectData, request: gradio.Request):
    """Store the index of the ECG selected in the gallery in the session.

    Raises:
        gradio.Error: If the session of the request is not initialized.
    """
    # Get selection index from Gallery select() event:
    # https://github.com/gradio-app/gradio/issues/1976#issuecomment-1726018500
    session = _getSession(request)
    session.Selected = event.index
    log(f'Session "{request.session_hash}": Selected ECG #{session.Selected + 1}')


# ###### Download CSV #######################################################
def downloadCSV(request: gradio.Request) -> pathlib.Path:
    """Write the selected ECG as CSV file into the session's temp directory.

    Returns:
        The path of the written CSV file.

    Raises:
        gradio.Error: If there is no session or no selectable ECG.
    """
    session, selected, ecgResult = _getSelectedECG(request)
    fileName = pathlib.Path(session.TempDirectory.name) / \
        ('ECG-' + str(selected + 1) + '.csv')
    deepfakeecg.dataToCSV(ecgResult, session.Type, fileName)
    log(f'Session "{request.session_hash}": Download CSV file {fileName}')
    return fileName


# ###### Download PDF #######################################################
def downloadPDF(request: gradio.Request):
    """Write the selected ECG as PDF file into the session's temp directory.

    Returns:
        The path of the written PDF file.

    Raises:
        gradio.Error: If there is no session or no selectable ECG.
    """
    session, selected, ecgResult = _getSelectedECG(request)
    ecgType = session.Type
    fileName = pathlib.Path(session.TempDirectory.name) / \
        ('ECG-' + str(selected + 1) + '.pdf')
    if ecgType == deepfakeecg.DATA_ECG12:
        outputLeads = [ 'I', 'II', 'III', 'aVL', 'aVR', 'aVF', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
    else:
        outputLeads = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
    deepfakeecg.dataToPDF(ecgResult, ecgType, outputLeads, fileName, selected + 1)
    log(f'Session "{request.session_hash}": Download PDF file {fileName}')
    return fileName


# ###### Analyze the selected ECG ###########################################
def analyze(request: gradio.Request):
    """Log the analysis request and print the selected ECG's data to stdout.

    Raises:
        gradio.Error: If there is no session or no selectable ECG.
    """
    session, selected, data = _getSelectedECG(request)
    log(f'Session "{request.session_hash}": Analyze ECG #{selected + 1}!')
    print(data)


# ###### Main program #######################################################

# ====== Initialise =========================================================
runOnDevice: typing.Literal['cpu', 'cuda'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE: a stray "r" after the closing brace of the "div.title" rule has been
# removed; it turned the following selector into "r img.logo-image", which
# cannot match a plain <img class="logo-image">.
css = r"""
div {
   background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-essen.png");
}

/* ###### General Settings ############################################## */
html, body {
   height: 100%;
   padding: 0;
   margin: 0;
   font-family: sans-serif;
   font-size: small;
   background-color: #E3E3E3;   /* Simula background colour: #E3E3E3 */
}

/* ###### Header ######################################################## */
div.header {
   background-image: none;
   background-color: #F15D22;   /* Simula header colour: #F15D22 */
   height: 7.5%;
   display: flex;
   justify-content: space-between;
}

div.logo-left {
   width: 12.5%;
   float: left;
   display: flex;
   padding: 0% 1%;
   align-items: center;
   background: white;
}

div.logo-right {
   width: 12.5%;
   float: right;
   display: flex;
   padding: 0% 1%;
   align-items: center;
   background: white;
}

div.title {
   display: flex;
   align-items: center;
   padding: 0% 1%;
   background-image: none;
   background-color: #F15D22;   /* Simula header colour: #F15D22 */
   font-family: "Ubuntu", sans-serif;
   font-size: 4vh;
   font-weight: bold;
}

img.logo-image {
   max-width: 100%;
   max-height: 100%;
}
"""

# ====== Create GUI =========================================================
# NOTE(review): the nesting of the layout containers below was reconstructed
# from whitespace-collapsed source -- confirm it against the intended layout.
with gradio.Blocks(css = css,
                   theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue)) as gui:

    # ====== Session handling ================================================
    # Session clean-up, to be called when page is closed/refreshed.
    # (Session initialization is registered at the end of this block, chained
    # with the initial ECG generation.)
    gui.unload(cleanUpSession)

    # ====== Header ==========================================================
    # NOTE(review): as visible here, this markup contains no elements using
    # the CSS classes defined above (header, logo-left, title, logo-right,
    # logo-image) -- the tags may have been lost; confirm against the original.
    big_block = gradio.HTML("""
SimulaMet
SEARCH Fake ECG Generator
NorNet
""")

    gradio.Markdown('## Settings')
    with gradio.Row():
        sliderNumberOfECGs = gradio.Slider(1, 100, label="Number of ECGs", step = 1, value = 4, interactive = True)
        # sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
        dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
        dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
        with gradio.Column():
            buttonGenerate = gradio.Button("Generate ECGs!")
            buttonAnalyze = gradio.Button("Analyze this ECG!")
            with gradio.Row():
                buttonCSV = gradio.DownloadButton("Download CSV")
                buttonCSV_hidden = gradio.DownloadButton(visible=False, elem_id="download_csv_hidden")
                buttonPDF = gradio.DownloadButton("Download PDF")
                buttonPDF_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdf_hidden")

    gradio.Markdown('## Output')
    with gradio.Row():
        outputGallery = gradio.Gallery(label = 'output',
                                       columns = [ 1 ],
                                       height = 'auto',
                                       show_label = True,
                                       preview = True)
        outputGallery.select(select)

    gradio.Markdown('## Analysis')

    # ====== Add click event handling for "Generate" button ==================
    buttonGenerate.click(predict,
                         inputs = [ sliderNumberOfECGs,
                                    # sliderLengthInSeconds,
                                    dropdownType,
                                    dropdownGeneratorModel ],
                         outputs = [ outputGallery ] )

    # ====== Add click event handling for "Analyze" button ===================
    buttonAnalyze.click(analyze)

    # ====== Add click event handling for download buttons ===================
    # Using hidden button and JavaScript, to generate download file on-the-fly:
    # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
    # Each visible button gets exactly ONE handler chain; an additional bare
    # click(downloadXXX) registration would generate every file twice.
    buttonCSV.click(fn = downloadCSV,
                    inputs = None,
                    outputs = [ buttonCSV_hidden ]).then(
                        fn = None, inputs = None, outputs = None,
                        js = "() => document.querySelector('#download_csv_hidden').click()")
    buttonPDF.click(fn = downloadPDF,
                    inputs = None,
                    outputs = [ buttonPDF_hidden ]).then(
                        fn = None, inputs = None, outputs = None,
                        js = "() => document.querySelector('#download_pdf_hidden').click()")

    # ====== Run on startup ==================================================
    # Initialize the session when the page is loaded, and only THEN generate
    # the initial ECGs. Two independent load() events are not ordered, so
    # predict() could otherwise run before the session exists.
    gui.load(initializeSession).then(predict,
                                     inputs = [ sliderNumberOfECGs,
                                                # sliderLengthInSeconds,
                                                dropdownType,
                                                dropdownGeneratorModel ],
                                     outputs = [ outputGallery ] )


# ====== Run the GUI ========================================================
if __name__ == "__main__":

    # ------ Prepare temporary directory -------------------------------------
    TempDirectory = tempfile.TemporaryDirectory(prefix = 'DeepFakeECGPlus-')
    log(f'Prepared temporary directory {TempDirectory.name}')

    try:
        # ------ Run the GUI, with downloads from temporary directory allowed -
        gui.launch(allowed_paths = [ TempDirectory.name ])
    finally:
        # ------ Clean up (also when launch() fails) --------------------------
        log(f'Cleaning up temporary directory {TempDirectory.name}')
        TempDirectory.cleanup()
    log('Done!')