Spaces:
Running
Running
File size: 3,261 Bytes
4d799f2 e1e6b13 843425c 5465560 843425c 5465560 843425c 64428bf 5465560 214fccd 5465560 214fccd 5465560 214fccd 5465560 214fccd 5465560 214fccd 5465560 64428bf 5465560 64428bf 214fccd 5465560 214fccd fe92162 5465560 4d799f2 5465560 4d799f2 ddae879 5465560 4d799f2 e1e6b13 4d799f2 5465560 214fccd 5465560 ddae879 4d799f2 5465560 4d799f2 5465560 214fccd 4d799f2 f63af71 64428bf ddae879 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import sys
import json
import pandas as pd
import gradio as gr
# 1) Adjust path before importing the loader
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
# Prepend (not append) so the bundled smi_ted_light package wins over any
# same-named package already installed in the environment.
sys.path.insert(0, INFERENCE_PATH)
from smi_ted_light.load import load_smi_ted

# 2) Load the SMI-TED Light model once at import time; the single `model`
# instance is shared by every Gradio request.
# NOTE(review): ckpt/vocab filenames must exist under MODEL_DIR —
# load_smi_ted is project code, so verify the exact filenames in the repo.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)
# 3) Single function to process either a single SMILES or a CSV of SMILES
# 3) Single function to process either a single SMILES or a CSV of SMILES
def process_inputs(smiles: str, file_obj):
    """Generate SMI-TED embeddings for a single SMILES or a CSV batch.

    Batch mode takes precedence: when ``file_obj`` is given, the first
    column of the CSV is read as SMILES and ``smiles`` is ignored.

    Parameters
    ----------
    smiles : str
        A single SMILES string (single mode).
    file_obj
        Optional uploaded CSV (batch mode); Gradio may supply either a
        tempfile-like object with a ``.name`` or a plain filepath string.

    Returns
    -------
    tuple
        ``(message, gr.update(...))`` — a status/JSON string plus an update
        that reveals the ``embeddings.csv`` download when one was written.
    """
    if file_obj is not None:
        try:
            # Newer Gradio versions pass a filepath string instead of a
            # tempfile wrapper with .name — accept both.
            csv_path = getattr(file_obj, "name", file_obj)
            df_in = pd.read_csv(csv_path)
            # First column holds the SMILES; strip whitespace and drop
            # blank rows so we never try to encode an empty string.
            smiles_list = [s for s in df_in.iloc[:, 0].astype(str).str.strip() if s]
            if not smiles_list:
                return "The uploaded CSV contains no SMILES.", gr.update(visible=False)
            embeddings = [
                model.encode(sm, return_torch=True)[0].tolist() for sm in smiles_list
            ]
            # Use the same dim_<i> header convention as single mode so both
            # modes emit an identically-shaped embeddings.csv.
            cols = ["smiles"] + [f"dim_{i}" for i in range(len(embeddings[0]))]
            out_df = pd.DataFrame(
                [[sm] + vec for sm, vec in zip(smiles_list, embeddings)],
                columns=cols,
            )
            out_df.to_csv("embeddings.csv", index=False)
            msg = f"Processed batch of {len(smiles_list)} SMILES. Download embeddings.csv."
            return msg, gr.update(value="embeddings.csv", visible=True)
        except Exception as e:
            # Surface the failure in the UI rather than crashing the app.
            return f"Error processing batch: {e}", gr.update(visible=False)

    # Otherwise, process a single SMILES
    smiles = smiles.strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
    try:
        vec = model.encode(smiles, return_torch=True)[0].tolist()
        # Save CSV with header
        cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
        df_out = pd.DataFrame([[smiles] + vec], columns=cols)
        df_out.to_csv("embeddings.csv", index=False)
        return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
    except Exception as e:
        return f"Error generating embedding: {e}", gr.update(visible=False)
# 4) Build the Gradio Blocks interface. Declaration order defines the page
# layout, so components are created top-to-bottom as they should render.
with gr.Blocks() as demo:
    # Usage instructions rendered at the top of the page.
    gr.Markdown(
        """
        # SMI-TED Embedding Generator
        **Single mode:** paste a SMILES string in the left box.
        **Batch mode:** upload a CSV file where each row has a SMILES in the first column.
        In both cases, an `embeddings.csv` file will be generated for download, with the first column as SMILES and the embedding values in the following columns.
        """
    )
    # Input row: free-text SMILES (single mode) next to the CSV upload (batch mode).
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])
    generate_btn = gr.Button("Generate Embeddings")
    # Output row: status/JSON textbox plus an initially-hidden download slot
    # that process_inputs toggles visible once embeddings.csv is written.
    with gr.Row():
        output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=2)
        download_csv = gr.File(label="Download embeddings.csv", visible=False)
    # Wire the button to the handler; outputs map 1:1 onto its return tuple.
    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv]
    )

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from inside a
    # container / HuggingFace Space, not just localhost.
    demo.launch(server_name="0.0.0.0")
|