File size: 3,261 Bytes
4d799f2
 
e1e6b13
 
 
843425c
5465560
843425c
 
 
 
 
 
5465560
843425c
64428bf
 
 
 
 
 
5465560
214fccd
5465560
214fccd
 
 
 
 
 
 
 
5465560
214fccd
 
 
5465560
214fccd
 
5465560
214fccd
5465560
64428bf
 
5465560
64428bf
214fccd
5465560
214fccd
 
 
 
fe92162
5465560
4d799f2
5465560
4d799f2
 
 
ddae879
5465560
 
 
4d799f2
 
e1e6b13
4d799f2
5465560
 
214fccd
5465560
ddae879
4d799f2
5465560
 
4d799f2
5465560
214fccd
 
 
4d799f2
f63af71
64428bf
ddae879
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import sys
import json
import pandas as pd
import gradio as gr

# 1) Adjust path before importing the loader
# The inference package lives under smi-ted/inference relative to this
# script, so it must be prepended to sys.path BEFORE the import below.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)

# Import is deliberately placed after the sys.path tweak; it would fail
# at the top of the file.
from smi_ted_light.load import load_smi_ted

# 2) Load the SMI-TED Light model
# Loaded once at module import time; every request reuses this instance.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)

# 3) Single function to process either a single SMILES or a CSV of SMILES
def process_inputs(smiles: str, file_obj):
    """Generate SMI-TED embeddings for one SMILES string or a CSV batch.

    Args:
        smiles: A single SMILES string (single mode). Ignored when a CSV
            file is provided. May arrive as ``None`` from a cleared textbox.
        file_obj: Optional uploaded CSV whose first column holds SMILES
            strings (batch mode). Depending on the Gradio version this is
            either a tempfile wrapper with a ``.name`` attribute or a
            plain filepath string.

    Returns:
        A ``(message, file_update)`` tuple: ``message`` is either the
        embedding as a JSON list (single mode) or a status/error string,
        and ``file_update`` is a ``gr.update`` toggling the download
        component's value and visibility.
    """
    # Batch mode takes precedence over the textbox.
    if file_obj is not None:
        try:
            # Newer Gradio versions pass a filepath string; older ones pass
            # a tempfile object exposing `.name`. Support both.
            csv_path = getattr(file_obj, "name", file_obj)
            df_in = pd.read_csv(csv_path)
            smiles_list = df_in.iloc[:, 0].astype(str).tolist()
            embeddings = [
                model.encode(sm, return_torch=True)[0].tolist()
                for sm in smiles_list
            ]
            # Name the columns dim_0..dim_{D-1} so the batch CSV matches
            # the single-SMILES schema (previously it had bare 0..D-1).
            n_dims = len(embeddings[0]) if embeddings else 0
            out_df = pd.DataFrame(
                embeddings, columns=[f"dim_{i}" for i in range(n_dims)]
            )
            out_df.insert(0, "smiles", smiles_list)
            # NOTE(review): a fixed filename in the CWD means concurrent
            # users overwrite each other's results — consider tempfile.
            out_df.to_csv("embeddings.csv", index=False)
            msg = f"Processed batch of {len(smiles_list)} SMILES. Download embeddings.csv."
            return msg, gr.update(value="embeddings.csv", visible=True)
        except Exception as e:
            # Boundary handler: surface the failure to the UI, hide download.
            return f"Error processing batch: {e}", gr.update(visible=False)

    # Otherwise, process a single SMILES. Guard against None as well as "".
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
    try:
        vec = model.encode(smiles, return_torch=True)[0].tolist()
        # Save a one-row CSV with the same header layout as batch mode.
        cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
        df_out = pd.DataFrame([[smiles] + vec], columns=cols)
        df_out.to_csv("embeddings.csv", index=False)
        return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
    except Exception as e:
        return f"Error generating embedding: {e}", gr.update(visible=False)

# 4) Build the Gradio Blocks interface
# Component creation order fixes the on-screen layout, so the statement
# sequence below is load-bearing.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # SMI-TED Embedding Generator  
        **Single mode:** paste a SMILES string in the left box.  
        **Batch mode:** upload a CSV file where each row has a SMILES in the first column.  
        In both cases, an `embeddings.csv` file will be generated for download, with the first column as SMILES and the embedding values in the following columns.
        """
    )

    # Input row: free-text SMILES on the left, CSV upload on the right.
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in   = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])

    generate_btn = gr.Button("Generate Embeddings")

    # Output row: status/JSON message plus a download link that
    # process_inputs toggles visible only on success.
    with gr.Row():
        output_msg   = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=2)
        download_csv = gr.File(label="Download embeddings.csv", visible=False)

    # Wire the button to the handler; outputs map 1:1 to the tuple it returns.
    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv]
    )

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a
    # container/VM, not just localhost.
    demo.launch(server_name="0.0.0.0")