import torch
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from proteins_viz import *
import gradio as gr
import spaces

def convert_outputs_to_pdb(outputs):
    # Convert ESMFold model outputs (atom14 representation) into PDB-format strings,
    # one per sequence in the batch.
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

# Move the model to the GPU and cast the ESM language-model stem to float16 to save memory.
model = model.cuda()
model.esm = model.esm.half()

# Allow TF32 matmuls and chunk the folding trunk to lower peak GPU memory usage.
torch.backends.cuda.matmul.allow_tf32 = True
model.trunk.set_chunk_size(64)

@spaces.GPU(duration=120)
def fold_protein(test_protein):
    # Tokenize the raw amino acid sequence (ESMFold expects no special tokens) and run inference.
    tokenized_input = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)['input_ids']
    tokenized_input = tokenized_input.cuda()
    with torch.no_grad():
        output = model(tokenized_input)
    # Write the predicted structure to a PDB file, render a static image of it, and return
    # both values so the interface's two outputs (image + downloadable file) are filled.
    pdb = convert_outputs_to_pdb(output)
    with open("output_structure.pdb", "w") as f:
        f.write("".join(pdb))
    image = take_care("output_structure.pdb")
    return image, "output_structure.pdb"

iface = gr.Interface(
    title="everything-ai-proteinfold",
    fn=fold_protein,
    inputs=gr.Textbox(
            label="Protein Sequence",
            info="Find example sequences below; complete examples with images are at https://github.com/AstraBert/proteinviz/tree/main/examples.md. Submitting a sequence returns a static image of the predicted structure and the PDB file to download and explore in a 3D viewer.",
            lines=5,
            placeholder="Paste or write an amino acid sequence here",
        ),
    outputs=[gr.Image(label="Protein static image"), gr.File(label="Predicted structure (PDB)")],
    examples=[
        "MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGKEFTPPVQAAYQKVVAGVANALAHKYH",
        "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGPGCMSCKCVLS",
        "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG",
    ]
)

iface.launch(server_name="0.0.0.0", share=False)
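
# Hedged usage sketch (not part of the original script): once the app is running, it can be
# queried programmatically with gradio_client. The local URL and api_name below are assumptions
# based on Gradio defaults, not values confirmed by this repository.
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   image_path, pdb_path = client.predict(
#       "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG",
#       api_name="/predict",
#   )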