import gradio as gr
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from constants import EMBEDDING_DIMENSION, LAUNCH_PARAMETERS, SUPPORTED_EMBEDDING_DIMENSIONS
from data import SAMPLE_SMILES
from service import MolecularEmbeddingService, SimilarMolecule, setup_logger
# Module-level logger used by the App methods below to record search requests and failures.
logger = setup_logger()
class App:
def __init__(self) -> None:
    """Create the embedding service and build the Gradio interface.

    The service is assigned first so it is already present on the instance
    when ``create_gradio_interface`` runs.
    """
    # Backend the search pipeline calls for embeddings and neighbor lookup.
    self.embedding_service = MolecularEmbeddingService()
    # Result of create_gradio_interface(), stored on the instance as ``demo``.
    self.demo = self.create_gradio_interface()
def molecule_similarity_search_pipeline(
self, smiles: str, embed_dim: int
) -> tuple[list[float], list[SimilarMolecule], str]:
"""Complete pipeline: SMILES -> Canonical SMILES -> Embedding -> Similar molecules"""
try:
if not smiles or smiles.strip() == "":
return [], [], "Please provide a valid SMILES string"
logger.info(f"Running similarity search: {smiles} - ({embed_dim})")
embedding = self.embedding_service.get_molecular_embedding(smiles, embed_dim)
neighbors = self.embedding_service.find_similar_molecules(embedding, embed_dim)
return embedding.tolist(), neighbors, "Search completed successfully"
except Exception as e:
error_msg = f"Search failed: {str(e)}"
logger.error(error_msg)
return [], [], error_msg
@staticmethod
def _truncated_attribute(obj, attr, max_len=45):
return f"{obj[attr][:max_len]}{'...' if len(obj[attr]) > max_len else ''}"
@classmethod
def _draw_molecule_grid(cls, similar: list[SimilarMolecule]) -> np.ndarray:
    """Render the given molecules as a three-column grid image with legends.

    Each legend has four lines: truncated name, properties, truncated SMILES
    and the score in scientific notation.

    Args:
        similar: Entries read through their ``smiles``, ``name``,
            ``properties`` and ``score`` keys.

    Returns:
        Whatever ``Draw.MolsToGridImage`` produces for these settings.
        NOTE(review): the annotation says ``np.ndarray``; confirm against the
        RDKit return type actually observed at runtime.
    """

    def legend_for(entry) -> str:
        # One line per field, in the order shown under each molecule.
        return "\n".join(
            (
                cls._truncated_attribute(entry, "name"),
                f"{entry['properties']}",
                cls._truncated_attribute(entry, "smiles"),
                f"{entry['score']:.2E}",
            )
        )

    # Parse every SMILES first, then build the matching legends.
    parsed = []
    for entry in similar:
        parsed.append(Chem.MolFromSmiles(entry["smiles"]))
    captions = []
    for entry in similar:
        captions.append(legend_for(entry))

    options = rdMolDraw2D.MolDrawOptions()
    options.legendFontSize = 17
    options.legendFraction = 0.29
    options.drawMolsSameScale = False

    return Draw.MolsToGridImage(
        parsed,
        legends=captions,
        molsPerRow=3,
        subImgSize=(250, 250),
        drawOptions=options,
    )
@staticmethod
def _display_sample_molecules(mols: pd.DataFrame):
    """Render a read-only SMILES textbox and a "Load" button for each sample row.

    Clicking a button runs a client-side snippet that passes the row's SMILES
    to ``window.setCWSmiles``; no Python callback is involved (``fn=None``).

    Args:
        mols: Table of sample molecules; every row is read through its
            ``smiles``, ``name`` and ``properties`` columns.

    NOTE(review): presumably called while a Gradio layout context is open so
    the created components attach to it; confirm against the caller.
    """
    import json  # local import: only needed to build the JS string literal

    for _, row in mols.iterrows():
        # json.dumps yields a valid JavaScript string literal. Interpolating
        # the raw SMILES between single quotes let JS swallow backslashes
        # (the SMILES cis/trans bond marker) and break on quote characters.
        smiles_literal = json.dumps(str(row["smiles"]))
        with gr.Group():
            gr.Textbox(
                value=row["smiles"],
                label=f"{row['name']} ({row['properties']})",
                interactive=False,
                scale=3,
            )
            sample_btn = gr.Button(
                f"Load {row['name']}",
                scale=1,
                size="sm",
                variant="primary",
            )
            sample_btn.click(
                fn=None,
                js=f"() => {{window.setCWSmiles({smiles_literal});}}",
            )
@staticmethod
def clear_all():
return "", "", [], [], None, "Cleared - Draw a new molecule or enter SMILES"
def handle_search(self, smiles: str, embed_dim: int):
    """Run a similarity search for the UI and render the hits as a grid image.

    Args:
        smiles: SMILES string taken from the editor or the text input.
        embed_dim: Embedding dimension forwarded to the search pipeline.

    Returns:
        ``(embedding, similar, image, status)``. ``image`` is ``None`` when
        the input is blank or the search produced no molecules to draw.
    """
    # Same blank-input rule as the pipeline: None, "" and whitespace-only
    # input are all rejected here instead of crashing on ``None.strip()``.
    if not smiles or not smiles.strip():
        return (
            [],
            [],
            None,
            "Please draw a molecule or enter a SMILES string",
        )
    embedding, similar, status = self.molecule_similarity_search_pipeline(smiles, embed_dim)
    # A failed or empty search yields no molecules. Drawing happens outside
    # the pipeline's error handling, so skip it rather than risk an error on
    # an empty grid hiding the status message.
    img = self._draw_molecule_grid(similar) if similar else None
    return embedding, similar, img, status
def create_gradio_interface(self):
"""Create the Gradio interface optimized for JavaScript client usage"""
head_scripts = """
"""
with gr.Blocks(
title="Chem-MRL: Molecular Similarity Search Demo",
theme=gr.themes.Soft(), # type: ignore
head=head_scripts,
) as demo:
gr.Markdown("""
# ๐งช Chem-MRL: Molecular Similarity Search Demo
Use the ChemWriter editor to draw a molecule or input a SMILES string.
The backend encodes the molecule using the Chem-MRL model to produce a vector embedding.
Similarity search is performed via an HNSW-indexed Redis vector store to retrieve closest matches.
""")
gr.HTML(
"""
The Redis database indexes Isomer Design's molecular library.
""", # noqa: E501
padding=False,
)
gr.Markdown(
"[Model Repo](https://github.com/emapco/chem-mrl) | [Demo Repo](https://github.com/emapco/chem-mrl-demo)"
)
with gr.Tab("๐ฌ Molecular Search"), gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Molecule Input")
gr.HTML(
'