import os
import sys
import json
import pandas as pd
import gradio as gr

# 1) Adjust the path before importing the loader
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)

from smi_ted_light.load import load_smi_ted

# 2) Load the model
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)
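# Note: the checkpoint (smi-ted-Light_40.pt) and vocabulary file are assumed to live
# inside MODEL_DIR; load_smi_ted receives them as filenames alongside `folder`.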

# 3) Single function that handles either a single SMILES string or a CSV of SMILES
def process_inputs(smiles: str, file_obj):
    # If a CSV file was uploaded, process it in batch mode
    if file_obj is not None:
        try:
            df_in = pd.read_csv(file_obj.name)
            smiles_list = df_in.iloc[:, 0].astype(str).tolist()
            embeddings = []
            for sm in smiles_list:
                vec = model.encode(sm, return_torch=True)[0].tolist()
                embeddings.append(vec)
            # Build the output DataFrame (embedding columns keep pandas' default integer names here)
            out_df = pd.DataFrame(embeddings)
            out_df.insert(0, "smiles", smiles_list)
            out_df.to_csv("embeddings.csv", index=False)
            msg = f"Batch of {len(smiles_list)} SMILES processed. Download embeddings.csv."
            return msg, gr.update(value="embeddings.csv", visible=True)
        except Exception as e:
            return f"Erro no batch: {e}", gr.update(visible=False)

    # Otherwise, process a single SMILES string
    smiles = smiles.strip()
    if not smiles:
        return "Digite um SMILES ou envie um arquivo CSV.", gr.update(visible=False)
    try:
        vec = model.encode(smiles, return_torch=True)[0].tolist()
        # Save a CSV with a header row
        cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
        df_out = pd.DataFrame([[smiles] + vec], columns=cols)
        df_out.to_csv("embeddings.csv", index=False)
        return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
    except Exception as e:
        return f"Erro ao gerar embedding: {e}", gr.update(visible=False)

# 4) Build the Blocks interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # SMI-TED Embedding Generator  
        **Single mode:** paste a SMILES string in the box on the left.  
        **Batch mode:** upload a CSV with one SMILES per row (they must be in the first column).  
        In both cases, an `embeddings.csv` file is generated for download, with the SMILES in the first column and the embedding in the following columns.
        """
    )

    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in   = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])

    gerar_btn = gr.Button("Generate Embeddings")

    with gr.Row():
        output_msg  = gr.Textbox(label="Response/Embedding (JSON)", interactive=False, lines=2)
        download_csv = gr.File(label="Download embeddings.csv", visible=False)

    gerar_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")
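
# server_name="0.0.0.0" exposes the app on all network interfaces (e.g. when running
# inside a container); add server_port=7860 to demo.launch() if a fixed port is needed.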