Spaces:
Running
Running
File size: 6,522 Bytes
8adb95d 10b1be2 8adb95d 10b1be2 8adb95d 59e713c 10b1be2 4fab10d 10b1be2 551d6ff 10b1be2 ba0c8e7 fb25f9c c5e4478 b7d6da0 59e713c fb25f9c c5e4478 10b1be2 abe6512 39579b4 abe6512 853bf92 39579b4 853bf92 8adb95d 59e713c 660f406 59e713c 660f406 8adb95d 10b1be2 083f3dd 1365361 59e713c 8adb95d fef897d c5e4478 8adb95d 10b1be2 c5e4478 59e713c 8adb95d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import json
import os
import subprocess

import gradio as gr
from gradio_moleculeview import moleculeview
import cellscape
# Export controls appended after the rendered SVG. The script looks the SVG up
# by the id 'svgElement', which predict() injects into the root <svg> tag.
_EXPORT_CONTROLS_HTML = """
<button id="copySvgBtn">Copy SVG to Clipboard</button>
<button id="copyPngBtn">Copy PNG to Clipboard</button>
<!-- Buttons for Download -->
<button id="downloadSvgBtn">Download SVG</button>
<button id="downloadPngBtn">Download PNG</button>
<script>
function copySvgToClipboard() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const blob = new Blob([svgData], { type: 'image/svg+xml' });
const clipboardItem = [new ClipboardItem({ 'image/svg+xml': blob })];
navigator.clipboard.write(clipboardItem).then(() => {
alert("SVG copied to clipboard!");
}).catch(err => {
console.error("Could not copy SVG to clipboard: ", err);
});
}
// Function to convert SVG to PNG and copy to clipboard
function copyPngToClipboard() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const img = new Image();
img.onload = function () {
canvas.width = svgElement.clientWidth;
canvas.height = svgElement.clientHeight;
ctx.drawImage(img, 0, 0);
canvas.toBlob(blob => {
const clipboardItem = [new ClipboardItem({ 'image/png': blob })];
navigator.clipboard.write(clipboardItem).then(() => {
alert("PNG copied to clipboard!");
}).catch(err => {
console.error("Could not copy PNG to clipboard: ", err);
});
}, 'image/png');
};
img.src = 'data:image/svg+xml;base64,' + btoa(svgData);
}
// Function to download SVG
function downloadSvg() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const blob = new Blob([svgData], { type: 'image/svg+xml' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'image.svg';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
// Function to download PNG
function downloadPng() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const img = new Image();
img.onload = function () {
canvas.width = svgElement.clientWidth;
canvas.height = svgElement.clientHeight;
ctx.drawImage(img, 0, 0);
canvas.toBlob(blob => {
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'image.png';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}, 'image/png');
};
img.src = 'data:image/svg+xml;base64,' + btoa(svgData);
}
// Button event listeners
document.getElementById('copySvgBtn').addEventListener('click', copySvgToClipboard);
document.getElementById('copyPngBtn').addEventListener('click', copyPngToClipboard);
document.getElementById('downloadSvgBtn').addEventListener('click', downloadSvg);
document.getElementById('downloadPngBtn').addEventListener('click', downloadPng);
</script>
"""


def _cellscape_command(pdb_path, style, contour_level, colors):
    """Return the ``cellscape cartoon`` argument list for the chosen style."""
    command = ["cellscape", "cartoon", "--pdb", pdb_path]
    if style == "Goodsell3D":
        command += [
            "--outline", "residue",
            "--color_by", "chain",
            "--depth_shading",
            "--depth_lines",
            "--colors", *colors,
            "--depth", "flat",
        ]
    elif style == "Contour":
        command += [
            "--outline", "chain",
            "--color_by", "chain",
            "--depth_contour_interval", str(contour_level),
            "--colors", *colors,
            "--depth", "contours",
        ]
    else:
        # Any other style value is rendered with the flat settings.
        command += [
            "--outline", "chain",
            "--colors", *colors,
            "--depth", "flat",
        ]
    command += [
        "--back_outline",
        "--view", "view_matrix",
        "--save", "outline_all.svg",
    ]
    return command


def _size_in_mib(path):
    """Return the size of the file at *path* in mebibytes."""
    return os.stat(path).st_size / (1024 * 1024)


def predict(input_mol, style, contour_level, view_str, chains):
    """Render a structure as an SVG cartoon and return it as embeddable HTML.

    Writes the view to ``view_matrix``, runs ``cellscape cartoon`` to produce
    ``outline_all.svg``, simplifies it with Inkscape into ``pdb_opt.svg``, and
    returns that SVG (root element tagged ``id='svgElement'``) followed by the
    copy/download controls. All files are written to the current working
    directory under fixed names.

    Args:
        input_mol: Object whose ``name`` attribute is the path of the structure
            file passed to ``--pdb`` (presumably the moleculeview upload;
            confirm against the component).
        style: ``"Goodsell3D"`` or ``"Contour"``; any other value uses the
            flat rendering.
        contour_level: Value passed to ``--depth_contour_interval``; only used
            for the ``"Contour"`` style.
        view_str: JSON text that decodes to a string; the decoded text is
            written verbatim to ``view_matrix``.
        chains: JSON object mapping chain identifiers to colours. Colours are
            passed to ``--colors`` in sorted chain-identifier order.

    Returns:
        A ``(html, path)`` tuple: the SVG markup plus export controls, and the
        path ``"pdb_opt.svg"``.

    Raises:
        ValueError: If no molecule was supplied, or if ``view_str`` / ``chains``
            is not valid JSON (``json.JSONDecodeError`` subclasses it).
        subprocess.CalledProcessError: If cellscape or Inkscape exits non-zero.
    """
    if input_mol is None:
        raise ValueError("No molecule was provided.")

    # Decode both inputs before touching the filesystem so malformed input
    # cannot leave a truncated view file behind.
    view_text = json.loads(view_str)
    chain_colors = json.loads(chains)
    colors = [str(chain_colors[chain]) for chain in sorted(chain_colors)]

    with open("view_matrix", "w", encoding="utf-8") as view_file:
        view_file.write(view_text)

    # Argument lists (no shell) keep the uploaded path and the client-supplied
    # colours from being interpreted as shell syntax. check=True stops here on
    # failure instead of carrying on with a stale SVG from an earlier request.
    subprocess.run(
        _cellscape_command(input_mol.name, style, contour_level, colors),
        check=True,
    )
    print(_size_in_mib("outline_all.svg"))

    subprocess.run(
        [
            "inkscape",
            "outline_all.svg",
            "--actions=select-all;path-simplify;export-plain-svg",
            "--export-filename",
            "pdb_opt.svg",
        ],
        check=True,
    )
    # Report the simplified file so the two printed sizes can be compared.
    print(_size_in_mib("pdb_opt.svg"))

    with open("pdb_opt.svg", "r", encoding="utf-8") as svg_file:
        svg_markup = svg_file.read()

    # Tag only the first (root) <svg> so the id stays unique in the page.
    tagged_svg = svg_markup.replace("<svg", "<svg id='svgElement'", 1)
    return tagged_svg + _EXPORT_CONTROLS_HTML, "pdb_opt.svg"
def show_contour_level(style):
    """Return the contour-level slider, visible only for the "Contour" style.

    Args:
        style: Currently selected rendering style.

    Returns:
        A ``gr.Slider`` (range 1-50, step 1, value 10) whose ``visible`` flag
        is True exactly when ``style`` equals ``"Contour"``.
    """
    is_contour = style == "Contour"
    return gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=10,
        label="Contour level",
        visible=is_contour,
    )
# Page layout and event wiring for the app.
with gr.Blocks() as demo:
    gr.Markdown(value="# PDB2Vector")

    # Style picker; the contour slider starts hidden and is shown or hidden by
    # show_contour_level whenever the style changes.
    style = gr.Radio(
        choices=["Flat", "Contour", "Goodsell3D"],
        value="Flat",
        label="Style",
    )
    contour_level = gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=10,
        label="Contour level",
        visible=False,
    )
    style.change(fn=show_contour_level, inputs=style, outputs=contour_level)

    inp = moleculeview(label="Molecule3D")

    # Hidden textboxes that receive values collected in the browser on click.
    view_str = gr.Textbox(value="viewMatrixResult", label="View Matrix", visible=False)
    chains = gr.Textbox(value="chainsResult", label="Chains", visible=False)
    hidden_style = gr.Textbox(visible=False)
    timestamp = gr.Textbox(visible=False)

    btn = gr.Button(value="Vectorize")
    html = gr.HTML(value="")
    out_file = gr.File(label="Download SVG")

    # The click has no Python handler: this JS reads the elements with ids
    # 'viewMatrixResult' and 'chains' and returns their values together with
    # the style and Date.now(), which fill the four hidden textboxes.
    collect_viewer_state_js = "(style) => [document.getElementById('viewMatrixResult').value, document.getElementById('chains').value, style, Date.now()]"
    btn.click(
        fn=None,
        inputs=style,
        outputs=[view_str, chains, hidden_style, timestamp],
        js=collect_viewer_state_js,
    )

    # Rendering is triggered by the timestamp textbox changing.
    timestamp.change(
        fn=predict,
        inputs=[inp, style, contour_level, view_str, chains],
        outputs=[html, out_file],
    )

if __name__ == "__main__":
    demo.launch()
|