Spaces:
Sleeping
Sleeping
File size: 1,175 Bytes
953f112 58b95fb 7bbfc34 3c8cf6c b5953e8 58b95fb 3c8cf6c b8a864b 25aa2ba 953f112 b5953e8 953f112 b5953e8 953f112 b5953e8 3c8cf6c b5953e8 f89f767 953f112 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import io
import gradio as gr
from transformers import AutoModel
import ecg_plot
import matplotlib.pyplot as plt
from PIL import Image
import torch
# Generative ECG model fetched from the Hugging Face Hub at import time;
# predict() calls it to synthesize recordings.
# NOTE(security): trust_remote_code=True downloads and executes the Python code
# published in the "deepsynthbody/deepfake_ecg" repository. Consider pinning
# `revision="<commit sha>"` so an upstream change cannot alter what runs here.
model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)
# Row order handed to ecg_plot.plot: I, II, III, aVR, aVL, aVF, V1..V6, given
# the layout produced by _expand_to_12_leads (rows 0-7 come straight from the
# model, rows 8-11 are the derived III, aVR, aVL, aVF).
_DISPLAY_ORDER = (0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7)


def _expand_to_12_leads(prediction):
    """Append the four derived limb leads to an 8-row lead tensor.

    Assumes row 0 is lead I and row 1 is lead II (the formulas below are the
    Einthoven/Goldberger relations under that assumption); rows 2-7 are
    presumably V1-V6 -- TODO confirm against the model card.

    Returns a tensor with four extra rows appended, in the order
    III, aVR, aVL, aVF.
    """
    lead_i, lead_ii = prediction[0], prediction[1]
    derived = torch.stack(
        (
            lead_ii - lead_i,           # III
            -0.5 * (lead_i + lead_ii),  # aVR
            lead_i - 0.5 * lead_ii,     # aVL
            lead_ii - 0.5 * lead_i,     # aVF
        )
    )
    return torch.cat((prediction, derived), dim=0)


def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG and return it as a PIL image.

    The figure is always closed afterwards (even if saving fails) so that
    repeated calls in a long-running server do not accumulate open figures.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    buf.seek(0)  # rewind explicitly before handing the buffer to PIL
    img = Image.open(buf)
    img.load()  # decode now instead of lazily on first access
    return img


def predict():
    """Generate one synthetic 12-lead ECG and return it as a rendered image.

    Takes no arguments; the Gradio interface is created with ``inputs=None``.
    """
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        # [0] takes the first generated recording; .t() swaps its two axes so
        # that leads become rows (assumed layout -- TODO confirm). The 1/1000
        # scale was commented "to micro volts"; if the model emits µV this
        # actually yields mV -- TODO confirm the model's output unit.
        prediction = model(1)[0].t() / 1000

    leads = _expand_to_12_leads(prediction)[torch.tensor(_DISPLAY_ORDER)]
    # Hand the plotting code a plain NumPy array rather than relying on an
    # implicit tensor-to-array conversion inside matplotlib.
    leads = leads.detach().cpu().numpy()

    ecg_plot.plot(leads, sample_rate=500, title="ECG 12")
    # ecg_plot.plot draws on the current pyplot figure; render and close it.
    return _figure_to_image(plt.gcf())
# Single-output Gradio app: no input widgets, one generated ECG image per run.
demo = gr.Interface(
    fn=predict,
    inputs=None,
    outputs="image",
    title="Generating Fake ECGs",
)
demo.launch()
|